From 0f77687483e78bdb599fae90dbb2c3c13640afae Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:41:32 +0530 Subject: [PATCH 1/8] [OpenVINO backend] Support numpy.diagonal issue 29115 --- .../openvino/excluded_concrete_tests.txt | 1 - keras/src/backend/openvino/numpy.py | 72 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index b89a56260802..4ef116fcc6a4 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -86,7 +86,6 @@ NumpyOneInputOpsCorrectnessTest::test_corrcoef NumpyOneInputOpsCorrectnessTest::test_correlate NumpyOneInputOpsCorrectnessTest::test_cumprod NumpyOneInputOpsCorrectnessTest::test_diag -NumpyOneInputOpsCorrectnessTest::test_diagonal NumpyOneInputOpsCorrectnessTest::test_exp2 NumpyOneInputOpsCorrectnessTest::test_flip NumpyOneInputOpsCorrectnessTest::test_floor_divide diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 8b3c5e0b5a96..345cae4470ad 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -15,6 +15,78 @@ from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import get_ov_output from keras.src.backend.openvino.core import ov_to_keras_type +# --- Chnage for issue 29115 --- +import openvino.runtime.opset14 as ov + +from .core import OpenVINOKerasTensor # already present in file +from .core import _convert_to_node, _wrap_node # adapt if your file names differ + +def diagonal(x, offset=0, axis1=0, axis2=1): + """OpenVINO backend decomposition for keras.ops.diagonal.""" + x_node = _convert_to_node(x) # -> ov.Node + offset_const = ov.constant(int(offset), dtype="i64") + + # rank & normalize axes + shape = ov.shape_of(x_node) # i64 vector + rank = ov.shape_of(shape) # scalar i64 (len of shape) + rank_val = ov.squeeze(rank) # [] -> scalar + axis1_node = ov.mod(ov.add(ov.constant(int(axis1), dtype="i64"), rank_val), rank_val) + axis2_node = ov.mod(ov.add(ov.constant(int(axis2), dtype="i64"), rank_val), rank_val) + + # If axis1 == axis2, behavior should match numpy error; Keras tests don't hit this, + # so we skip explicit assert to keep graph-friendly. + + # Build permutation to move axis1, axis2 to the end + # perm = [all axes except axis1/axis2 in order] + [axis1, axis2] + arange = ov.range(ov.constant(0, dtype="i64"), rank_val, ov.constant(1, dtype="i64")) + mask1 = ov.equal(arange, axis1_node) + mask2 = ov.equal(arange, axis2_node) + not12 = ov.logical_not(ov.logical_or(mask1, mask2)) + others = ov.squeeze(ov.non_zero(not12), [1]) # gather positions != axis1, axis2 + perm = ov.concat([others, ov.reshape(axis1_node, [1]), ov.reshape(axis2_node, [1])], 0) + + x_perm = ov.transpose(x_node, perm) + permuted_shape = ov.shape_of(x_perm) + # last two dims + last2 = ov.gather(permuted_shape, ov.constant([-2, -1], dtype="i64"), ov.constant(0, dtype="i64")) + d1 = ov.gather(permuted_shape, ov.constant([-2], dtype="i64"), ov.constant(0, dtype="i64")) + d2 = ov.gather(permuted_shape, ov.constant([-1], dtype="i64"), ov.constant(0, dtype="i64")) + d1 = ov.squeeze(d1) # scalar + d2 = ov.squeeze(d2) # scalar + + # start1 = max(0, offset), start2 = max(0, -offset) + zero = ov.constant(0, dtype="i64") + start1 = ov.maximum(zero, offset_const) + start2 = ov.maximum(zero, ov.negative(offset_const)) + + # L = min(d1 - start1, d2 - start2) + l1 = ov.subtract(d1, start1) + l2 = ov.subtract(d2, start2) + L = ov.minimum(l1, l2) + + # r = range(0, L, 1) -> shape [L] + r = ov.range(zero, L, ov.constant(1, dtype="i64")) + idx_row = ov.add(r, start1) + idx_col = ov.add(r, start2) + idx_row = ov.unsqueeze(idx_row, ov.constant(1, dtype="i64")) # [L,1] + idx_col = ov.unsqueeze(idx_col, ov.constant(1, dtype="i64")) # [L,1] + diag_idx = ov.concat([idx_row, idx_col], 1) # [L,2] + + # Broadcast indices to batch dims: target shape = (*batch, L, 2) + # batch_rank = rank(x) - 2 + two = ov.constant(2, dtype="i64") + batch_rank = ov.subtract(rank_val, two) + # build target shape: concat(permuted_shape[:batch_rank], [L, 2]) + batch_shape = ov.slice(permuted_shape, ov.constant([0], dtype="i64"), + ov.reshape(batch_rank, [1]), ov.constant([1], dtype="i64")) + target_shape = ov.concat([batch_shape, ov.reshape(L, [1]), ov.constant([2], dtype="i64")], 0) + bcast_idx = ov.broadcast(diag_idx, target_shape) + + # GatherND with batch_dims = batch_rank + gathered = ov.gather_nd(x_perm, bcast_idx, batch_rank) + + return OpenVINOKerasTensor(gathered) + def add(x1, x2): From ea0a40fa6be565afc5eca5c3a6516e0b23efdcad Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:58:24 +0530 Subject: [PATCH 2/8] gmni comit --- keras/src/backend/openvino/numpy.py | 114 +++++++++++++++++----------- 1 file changed, 69 insertions(+), 45 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 345cae4470ad..c4b70c253439 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -15,75 +15,99 @@ from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import get_ov_output from keras.src.backend.openvino.core import ov_to_keras_type -# --- Chnage for issue 29115 --- -import openvino.runtime.opset14 as ov - -from .core import OpenVINOKerasTensor # already present in file -from .core import _convert_to_node, _wrap_node # adapt if your file names differ +from .core import _convert_to_node def diagonal(x, offset=0, axis1=0, axis2=1): """OpenVINO backend decomposition for keras.ops.diagonal.""" x_node = _convert_to_node(x) # -> ov.Node - offset_const = ov.constant(int(offset), dtype="i64") + offset_const = ov_opset.constant(int(offset), dtype="i64") # rank & normalize axes - shape = ov.shape_of(x_node) # i64 vector - rank = ov.shape_of(shape) # scalar i64 (len of shape) - rank_val = ov.squeeze(rank) # [] -> scalar - axis1_node = ov.mod(ov.add(ov.constant(int(axis1), dtype="i64"), rank_val), rank_val) - axis2_node = ov.mod(ov.add(ov.constant(int(axis2), dtype="i64"), rank_val), rank_val) + shape = ov_opset.shape_of(x_node) # i64 vector + rank = ov_opset.shape_of(shape) # scalar i64 (len of shape) + rank_val = ov_opset.squeeze(rank) # [] -> scalar + axis1_node = ov_opset.floor_mod( + ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), rank_val + ) + axis2_node = ov_opset.floor_mod( + ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val + ) # If axis1 == axis2, behavior should match numpy error; Keras tests don't hit this, # so we skip explicit assert to keep graph-friendly. # Build permutation to move axis1, axis2 to the end # perm = [all axes except axis1/axis2 in order] + [axis1, axis2] - arange = ov.range(ov.constant(0, dtype="i64"), rank_val, ov.constant(1, dtype="i64")) - mask1 = ov.equal(arange, axis1_node) - mask2 = ov.equal(arange, axis2_node) - not12 = ov.logical_not(ov.logical_or(mask1, mask2)) - others = ov.squeeze(ov.non_zero(not12), [1]) # gather positions != axis1, axis2 - perm = ov.concat([others, ov.reshape(axis1_node, [1]), ov.reshape(axis2_node, [1])], 0) - - x_perm = ov.transpose(x_node, perm) - permuted_shape = ov.shape_of(x_perm) - # last two dims - last2 = ov.gather(permuted_shape, ov.constant([-2, -1], dtype="i64"), ov.constant(0, dtype="i64")) - d1 = ov.gather(permuted_shape, ov.constant([-2], dtype="i64"), ov.constant(0, dtype="i64")) - d2 = ov.gather(permuted_shape, ov.constant([-1], dtype="i64"), ov.constant(0, dtype="i64")) - d1 = ov.squeeze(d1) # scalar - d2 = ov.squeeze(d2) # scalar + arange = ov_opset.range( + ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64") + ) + mask1 = ov_opset.equal(arange, axis1_node) + mask2 = ov_opset.equal(arange, axis2_node) + not12 = ov_opset.logical_not(ov_opset.logical_or(mask1, mask2)) + others = ov_opset.squeeze( + ov_opset.non_zero(not12), [1] + ) # gather positions != axis1, axis2 + perm = ov_opset.concat( + [others, ov_opset.reshape(axis1_node, [1]), ov_opset.reshape(axis2_node, [1])], 0 + ) + + x_perm = ov_opset.transpose(x_node, perm) + permuted_shape = ov_opset.shape_of(x_perm) + d1 = ov_opset.gather( + permuted_shape, + ov_opset.constant([-2], dtype="i64"), + ov_opset.constant(0, dtype="i64"), + ) + d2 = ov_opset.gather( + permuted_shape, + ov_opset.constant([-1], dtype="i64"), + ov_opset.constant(0, dtype="i64"), + ) + d1 = ov_opset.squeeze(d1) # scalar + d2 = ov_opset.squeeze(d2) # scalar # start1 = max(0, offset), start2 = max(0, -offset) - zero = ov.constant(0, dtype="i64") - start1 = ov.maximum(zero, offset_const) - start2 = ov.maximum(zero, ov.negative(offset_const)) + zero = ov_opset.constant(0, dtype="i64") + start1 = ov_opset.maximum(zero, offset_const) + start2 = ov_opset.maximum(zero, ov_opset.negative(offset_const)) # L = min(d1 - start1, d2 - start2) - l1 = ov.subtract(d1, start1) - l2 = ov.subtract(d2, start2) - L = ov.minimum(l1, l2) + l1 = ov_opset.subtract(d1, start1) + l2 = ov_opset.subtract(d2, start2) + L = ov_opset.minimum(l1, l2) # r = range(0, L, 1) -> shape [L] - r = ov.range(zero, L, ov.constant(1, dtype="i64")) - idx_row = ov.add(r, start1) - idx_col = ov.add(r, start2) - idx_row = ov.unsqueeze(idx_row, ov.constant(1, dtype="i64")) # [L,1] - idx_col = ov.unsqueeze(idx_col, ov.constant(1, dtype="i64")) # [L,1] - diag_idx = ov.concat([idx_row, idx_col], 1) # [L,2] + r = ov_opset.range(zero, L, ov_opset.constant(1, dtype="i64")) + idx_row = ov_opset.add(r, start1) + idx_col = ov_opset.add(r, start2) + idx_row = ov_opset.unsqueeze( + idx_row, ov_opset.constant(1, dtype="i64") + ) # [L,1] + idx_col = ov_opset.unsqueeze( + idx_col, ov_opset.constant(1, dtype="i64") + ) # [L,1] + diag_idx = ov_opset.concat([idx_row, idx_col], 1) # [L,2] # Broadcast indices to batch dims: target shape = (*batch, L, 2) # batch_rank = rank(x) - 2 - two = ov.constant(2, dtype="i64") - batch_rank = ov.subtract(rank_val, two) + two = ov_opset.constant(2, dtype="i64") + batch_rank = ov_opset.subtract(rank_val, two) # build target shape: concat(permuted_shape[:batch_rank], [L, 2]) - batch_shape = ov.slice(permuted_shape, ov.constant([0], dtype="i64"), - ov.reshape(batch_rank, [1]), ov.constant([1], dtype="i64")) - target_shape = ov.concat([batch_shape, ov.reshape(L, [1]), ov.constant([2], dtype="i64")], 0) - bcast_idx = ov.broadcast(diag_idx, target_shape) + batch_shape = ov_opset.strided_slice( + permuted_shape, + begin=ov_opset.constant([0], dtype="i64"), + end=ov_opset.reshape(batch_rank, [1]), + strides=ov_opset.constant([1], dtype="i64"), + begin_mask=[0], + end_mask=[0], + ) + target_shape = ov_opset.concat( + [batch_shape, ov_opset.reshape(L, [1]), ov_opset.constant([2], dtype="i64")], 0 + ) + bcast_idx = ov_opset.broadcast(diag_idx, target_shape) # GatherND with batch_dims = batch_rank - gathered = ov.gather_nd(x_perm, bcast_idx, batch_rank) + gathered = ov_opset.gather_nd(x_perm, bcast_idx, batch_rank) return OpenVINOKerasTensor(gathered) From df028d02c1381a6f0caa6362048600e54b774b91 Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 16:44:31 +0530 Subject: [PATCH 3/8] fixxx --- keras/src/backend/openvino/numpy.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index c4b70c253439..71bd78775068 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -33,11 +33,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1): ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val ) - # If axis1 == axis2, behavior should match numpy error; Keras tests don't hit this, - # so we skip explicit assert to keep graph-friendly. - # Build permutation to move axis1, axis2 to the end - # perm = [all axes except axis1/axis2 in order] + [axis1, axis2] arange = ov_opset.range( ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64") ) @@ -773,11 +769,6 @@ def diag(x, k=0): raise NotImplementedError("`diag` is not supported with openvino backend") -def diagonal(x, offset=0, axis1=0, axis2=1): - raise NotImplementedError( - "`diagonal` is not supported with openvino backend" - ) - def diff(a, n=1, axis=-1): if n == 0: From 1635157e59afb70d371cce2bf5bb37879d81466b Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 16:48:59 +0530 Subject: [PATCH 4/8] convertcde --- keras/src/backend/openvino/numpy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 71bd78775068..86b7c7fcae9f 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -15,11 +15,14 @@ from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import get_ov_output from keras.src.backend.openvino.core import ov_to_keras_type -from .core import _convert_to_node +import numpy as np +from openvino.runtime import opset13 as ov + + def diagonal(x, offset=0, axis1=0, axis2=1): """OpenVINO backend decomposition for keras.ops.diagonal.""" - x_node = _convert_to_node(x) # -> ov.Node + x_node = ov.constant(x) # -> ov.Node offset_const = ov_opset.constant(int(offset), dtype="i64") # rank & normalize axes From 96802c5987c8c317aa856411a8bbbcc043ec4680 Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 17:31:37 +0530 Subject: [PATCH 5/8] Update numpy.py --- keras/src/backend/openvino/numpy.py | 37 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 86b7c7fcae9f..a9ccc1c094ce 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,6 +1,7 @@ import numpy as np import openvino.opset14 as ov_opset from openvino import Type +from openvino.runtime import opset13 as ov from keras.src.backend import config from keras.src.backend.common import dtypes @@ -15,13 +16,9 @@ from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import get_ov_output from keras.src.backend.openvino.core import ov_to_keras_type -import numpy as np -from openvino.runtime import opset13 as ov - def diagonal(x, offset=0, axis1=0, axis2=1): - """OpenVINO backend decomposition for keras.ops.diagonal.""" x_node = ov.constant(x) # -> ov.Node offset_const = ov_opset.constant(int(offset), dtype="i64") @@ -30,15 +27,17 @@ def diagonal(x, offset=0, axis1=0, axis2=1): rank = ov_opset.shape_of(shape) # scalar i64 (len of shape) rank_val = ov_opset.squeeze(rank) # [] -> scalar axis1_node = ov_opset.floor_mod( - ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), rank_val + ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), + rank_val, ) axis2_node = ov_opset.floor_mod( - ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val + ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), + rank_val, ) - - arange = ov_opset.range( - ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64") + ov_opset.constant(0, dtype="i64"), + rank_val, + ov_opset.constant(1, dtype="i64"), ) mask1 = ov_opset.equal(arange, axis1_node) mask2 = ov_opset.equal(arange, axis2_node) @@ -47,9 +46,14 @@ def diagonal(x, offset=0, axis1=0, axis2=1): ov_opset.non_zero(not12), [1] ) # gather positions != axis1, axis2 perm = ov_opset.concat( - [others, ov_opset.reshape(axis1_node, [1]), ov_opset.reshape(axis2_node, [1])], 0 + [ + others, + ov_opset.reshape(axis1_node, [1]), + ov_opset.reshape(axis2_node, [1]), + ], + 0, ) - + x_perm = ov_opset.transpose(x_node, perm) permuted_shape = ov_opset.shape_of(x_perm) d1 = ov_opset.gather( @@ -101,7 +105,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1): end_mask=[0], ) target_shape = ov_opset.concat( - [batch_shape, ov_opset.reshape(L, [1]), ov_opset.constant([2], dtype="i64")], 0 + [ + batch_shape, + ov_opset.reshape(L, [1]), + ov_opset.constant([2], dtype="i64"), + ], + 0, ) bcast_idx = ov_opset.broadcast(diag_idx, target_shape) @@ -110,8 +119,6 @@ def diagonal(x, offset=0, axis1=0, axis2=1): return OpenVINOKerasTensor(gathered) - - def add(x1, x2): element_type = None if isinstance(x1, OpenVINOKerasTensor): @@ -771,8 +778,6 @@ def deg2rad(x): def diag(x, k=0): raise NotImplementedError("`diag` is not supported with openvino backend") - - def diff(a, n=1, axis=-1): if n == 0: return OpenVINOKerasTensor(get_ov_output(a)) From f450869e198589fad9713f531c20931f96ea72a0 Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:08:33 +0000 Subject: [PATCH 6/8] Run api-gen hook and update API directory --- keras/src/backend/openvino/numpy.py | 32 ++++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 86b7c7fcae9f..bc01bbe93ac4 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,6 +1,7 @@ import numpy as np import openvino.opset14 as ov_opset from openvino import Type +from openvino.runtime import opset13 as ov from keras.src.backend import config from keras.src.backend.common import dtypes @@ -15,13 +16,9 @@ from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import get_ov_output from keras.src.backend.openvino.core import ov_to_keras_type -import numpy as np -from openvino.runtime import opset13 as ov - def diagonal(x, offset=0, axis1=0, axis2=1): - """OpenVINO backend decomposition for keras.ops.diagonal.""" x_node = ov.constant(x) # -> ov.Node offset_const = ov_opset.constant(int(offset), dtype="i64") @@ -30,15 +27,18 @@ def diagonal(x, offset=0, axis1=0, axis2=1): rank = ov_opset.shape_of(shape) # scalar i64 (len of shape) rank_val = ov_opset.squeeze(rank) # [] -> scalar axis1_node = ov_opset.floor_mod( - ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), rank_val + ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), + rank_val, ) axis2_node = ov_opset.floor_mod( - ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val + ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), + rank_val, ) - arange = ov_opset.range( - ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64") + ov_opset.constant(0, dtype="i64"), + rank_val, + ov_opset.constant(1, dtype="i64"), ) mask1 = ov_opset.equal(arange, axis1_node) mask2 = ov_opset.equal(arange, axis2_node) @@ -47,7 +47,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1): ov_opset.non_zero(not12), [1] ) # gather positions != axis1, axis2 perm = ov_opset.concat( - [others, ov_opset.reshape(axis1_node, [1]), ov_opset.reshape(axis2_node, [1])], 0 + [ + others, + ov_opset.reshape(axis1_node, [1]), + ov_opset.reshape(axis2_node, [1]), + ], + 0, ) x_perm = ov_opset.transpose(x_node, perm) @@ -101,7 +106,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1): end_mask=[0], ) target_shape = ov_opset.concat( - [batch_shape, ov_opset.reshape(L, [1]), ov_opset.constant([2], dtype="i64")], 0 + [ + batch_shape, + ov_opset.reshape(L, [1]), + ov_opset.constant([2], dtype="i64"), + ], + 0, ) bcast_idx = ov_opset.broadcast(diag_idx, target_shape) @@ -111,7 +121,6 @@ def diagonal(x, offset=0, axis1=0, axis2=1): return OpenVINOKerasTensor(gathered) - def add(x1, x2): element_type = None if isinstance(x1, OpenVINOKerasTensor): @@ -772,7 +781,6 @@ def diag(x, k=0): raise NotImplementedError("`diag` is not supported with openvino backend") - def diff(a, n=1, axis=-1): if n == 0: return OpenVINOKerasTensor(get_ov_output(a)) From 5ac523ef0bad8150f9b30e0e44f2f8d70eb28fba Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Sat, 23 Aug 2025 09:18:48 +0000 Subject: [PATCH 7/8] Run api-gen and commit generated API directory --- .github/workflows/scripts/labeler.js | 70 ++++++++++++++++++---------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js index 769683174688..a9084c187b7f 100644 --- a/.github/workflows/scripts/labeler.js +++ b/.github/workflows/scripts/labeler.js @@ -13,37 +13,59 @@ You may obtain a copy of the License at limitations under the License. */ - /** * Invoked from labeler.yaml file to add * label 'Gemma' to the issue and PR for which have gemma keyword present. * @param {!Object.} github contains pre defined functions. - * context Information about the workflow run. + * context Information about the workflow run. */ module.exports = async ({ github, context }) => { - const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title - const issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body - const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number + // Determine if the event is an issue or a pull request. + const isIssue = !!context.payload.issue; + + // Get the issue/PR title, description, and number from the payload. + // Use an empty string for the description if it's null to prevent runtime errors. + const issue_title = isIssue ? context.payload.issue.title : context.payload.pull_request.title; + const issue_description = (isIssue ? context.payload.issue.body : context.payload.pull_request.body) || ''; + const issue_number = isIssue ? context.payload.issue.number : context.payload.pull_request.number; + + // Define the keyword-to-label mapping. const keyword_label = { - gemma:'Gemma' - } - const labelsToAdd = [] - console.log(issue_title,issue_description,issue_number) + gemma: 'Gemma' + }; - for(const [keyword, label] of Object.entries(keyword_label)){ - if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){ - console.log(`'${keyword}'keyword is present inside the title or description. Pushing label '${label}' to row.`) - labelsToAdd.push(label) + // Array to hold labels that need to be added. + const labelsToAdd = []; + + console.log(`Processing event for issue/PR #${issue_number}: "${issue_title}"`); + + // Loop through the keywords and check if they exist in the title or description. + for (const [keyword, label] of Object.entries(keyword_label)) { + // Use .includes() for a cleaner and more modern check. + if (issue_title.toLowerCase().includes(keyword) || issue_description.toLowerCase().includes(keyword)) { + console.log(`'${keyword}' keyword is present in the title or description. Pushing label '${label}' to the array.`); + labelsToAdd.push(label); + } + } + + // Add labels if the labelsToAdd array is not empty. + if (labelsToAdd.length > 0) { + console.log(`Adding labels ${labelsToAdd} to issue/PR '#${issue_number}'.`); + + try { + // Await the asynchronous API call to ensure it completes. + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue_number, // Use the correct issue_number variable + labels: labelsToAdd + }); + console.log(`Successfully added labels.`); + } catch (error) { + console.error(`Failed to add labels: ${error.message}`); + } + } else { + console.log("No matching keywords found. No labels to add."); } - } - if(labelsToAdd.length > 0){ - console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`) - github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: labelsToAdd - }) - } -}; \ No newline at end of file +}; From 35071fd7a18bc6e525e523b2ed72e1ad897030cd Mon Sep 17 00:00:00 2001 From: arjunverma2004 <163740749+arjunverma2004@users.noreply.github.com> Date: Sat, 23 Aug 2025 17:18:28 +0530 Subject: [PATCH 8/8] Update labeler.js --- .github/workflows/scripts/labeler.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js index a9084c187b7f..17205f47cca8 100644 --- a/.github/workflows/scripts/labeler.js +++ b/.github/workflows/scripts/labeler.js @@ -21,6 +21,7 @@ You may obtain a copy of the License at */ module.exports = async ({ github, context }) => { + // Determine if the event is an issue or a pull request. const isIssue = !!context.payload.issue; @@ -34,7 +35,6 @@ module.exports = async ({ github, context }) => { const keyword_label = { gemma: 'Gemma' }; - // Array to hold labels that need to be added. const labelsToAdd = [];