Skip to content

Commit a821043

Browse files
Merge branch 'support_triu' into gsoc2025
2 parents 74b5997 + 4b66e60 commit a821043

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
NumPyTestRot90
22
NumpyArrayCreateOpsCorrectnessTest::test_eye
3-
NumpyArrayCreateOpsCorrectnessTest::test_tri
43
NumpyDtypeTest::test_absolute_bool
54
NumpyDtypeTest::test_add_
65
NumpyDtypeTest::test_all
@@ -61,7 +60,6 @@ NumpyDtypeTest::test_swapaxes
6160
NumpyDtypeTest::test_tensordot_
6261
NumpyDtypeTest::test_tile
6362
NumpyDtypeTest::test_trace
64-
NumpyDtypeTest::test_tri
6563
NumpyDtypeTest::test_trunc
6664
NumpyDtypeTest::test_unravel
6765
NumpyDtypeTest::test_var
@@ -126,9 +124,6 @@ NumpyOneInputOpsCorrectnessTest::test_swapaxes
126124
NumpyOneInputOpsCorrectnessTest::test_tile
127125
NumpyOneInputOpsCorrectnessTest::test_trace
128126
NumpyOneInputOpsCorrectnessTest::test_transpose
129-
NumpyOneInputOpsCorrectnessTest::test_tril
130-
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
131-
NumpyOneInputOpsCorrectnessTest::test_triu
132127
NumpyOneInputOpsCorrectnessTest::test_trunc
133128
NumpyOneInputOpsCorrectnessTest::test_unravel_index
134129
NumpyOneInputOpsCorrectnessTest::test_var

keras/src/backend/openvino/numpy.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,15 +1654,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
16541654

16551655

16561656
def tri(N, M=None, k=0, dtype=None):
1657-
raise NotImplementedError("`tri` is not supported with openvino backend")
1657+
if M is None:
1658+
M = N
1659+
if dtype is None:
1660+
dtype = "float32"
1661+
1662+
ov_dtype = OPENVINO_DTYPES[dtype]
1663+
1664+
def ensure_constant(value, default_type=Type.i32):
1665+
if isinstance(value, (int, float)):
1666+
return ov_opset.constant(value, default_type)
1667+
elif hasattr(value, "get_element_type"):
1668+
if value.get_element_type() != Type.i32:
1669+
value = ov_opset.convert(value, Type.i32)
1670+
return ov_opset.squeeze(value, ov_opset.constant([0], Type.i32))
1671+
else:
1672+
return ov_opset.constant(value, default_type)
1673+
1674+
N_const = ensure_constant(N)
1675+
M_const = ensure_constant(M)
1676+
k_const = ensure_constant(k)
1677+
1678+
# Create row and column indices
1679+
row_range = ov_opset.range(
1680+
ov_opset.constant(0, Type.i32),
1681+
N_const,
1682+
ov_opset.constant(1, Type.i32),
1683+
output_type=Type.i32,
1684+
)
1685+
col_range = ov_opset.range(
1686+
ov_opset.constant(0, Type.i32),
1687+
M_const,
1688+
ov_opset.constant(1, Type.i32),
1689+
output_type=Type.i32,
1690+
)
1691+
1692+
# Reshape indices for broadcasting
1693+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1694+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1695+
1696+
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))
1697+
1698+
if ov_dtype == Type.boolean:
1699+
result = mask
1700+
else:
1701+
result = ov_opset.convert(mask, ov_dtype)
1702+
1703+
return OpenVINOKerasTensor(result.output(0))
16581704

16591705

16601706
def tril(x, k=0):
1661-
raise NotImplementedError("`tril` is not supported with openvino backend")
1707+
x = get_ov_output(x)
1708+
ov_type = x.get_element_type()
1709+
shape = ov_opset.shape_of(x, Type.i32)
1710+
zero_const = ov_opset.constant(0, Type.i32)
1711+
minus2 = ov_opset.constant([-2], Type.i32)
1712+
minus1 = ov_opset.constant([-1], Type.i32)
1713+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1714+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1715+
tri_mask = tri(M, N, k=k, dtype="bool").output
1716+
mask = ov_opset.convert(tri_mask, ov_type)
1717+
if ov_type == Type.boolean:
1718+
out = ov_opset.logical_and(x, mask)
1719+
else:
1720+
out = ov_opset.multiply(x, mask)
1721+
return OpenVINOKerasTensor(out.output(0))
16621722

16631723

16641724
def triu(x, k=0):
1665-
raise NotImplementedError("`triu` is not supported with openvino backend")
1725+
x = get_ov_output(x)
1726+
ov_type = x.get_element_type()
1727+
shape = ov_opset.shape_of(x, Type.i32)
1728+
zero_const = ov_opset.constant(0, Type.i32)
1729+
minus2 = ov_opset.constant([-2], Type.i32)
1730+
minus1 = ov_opset.constant([-1], Type.i32)
1731+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1732+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1733+
tri_mask = tri(M, N, k=k - 1, dtype="bool").output
1734+
if ov_type == Type.boolean:
1735+
mask = ov_opset.logical_not(tri_mask)
1736+
else:
1737+
const_one = ov_opset.constant(1, ov_type)
1738+
converted_mask = ov_opset.convert(tri_mask, ov_type)
1739+
mask = ov_opset.subtract(const_one, converted_mask)
1740+
if ov_type == Type.boolean:
1741+
out = ov_opset.logical_and(x, mask)
1742+
else:
1743+
out = ov_opset.multiply(x, mask)
1744+
return OpenVINOKerasTensor(out.output(0))
16661745

16671746

16681747
def vdot(x1, x2):

0 commit comments

Comments
 (0)