Skip to content

Commit 4b66e60

Browse files
[OpenVINO backend] support tri, triu, and tril
1 parent d55a767 commit 4b66e60

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
NumPyTestRot90
22
NumpyArrayCreateOpsCorrectnessTest::test_eye
3-
NumpyArrayCreateOpsCorrectnessTest::test_tri
43
NumpyDtypeTest::test_absolute_bool
54
NumpyDtypeTest::test_add_
65
NumpyDtypeTest::test_all
@@ -62,7 +61,6 @@ NumpyDtypeTest::test_take_along_axis
6261
NumpyDtypeTest::test_tensordot_
6362
NumpyDtypeTest::test_tile
6463
NumpyDtypeTest::test_trace
65-
NumpyDtypeTest::test_tri
6664
NumpyDtypeTest::test_trunc
6765
NumpyDtypeTest::test_unravel
6866
NumpyDtypeTest::test_var
@@ -128,9 +126,6 @@ NumpyOneInputOpsCorrectnessTest::test_swapaxes
128126
NumpyOneInputOpsCorrectnessTest::test_tile
129127
NumpyOneInputOpsCorrectnessTest::test_trace
130128
NumpyOneInputOpsCorrectnessTest::test_transpose
131-
NumpyOneInputOpsCorrectnessTest::test_tril
132-
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
133-
NumpyOneInputOpsCorrectnessTest::test_triu
134129
NumpyOneInputOpsCorrectnessTest::test_trunc
135130
NumpyOneInputOpsCorrectnessTest::test_unravel_index
136131
NumpyOneInputOpsCorrectnessTest::test_var

keras/src/backend/openvino/numpy.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,15 +1494,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
14941494

14951495

14961496
def tri(N, M=None, k=0, dtype=None):
1497-
raise NotImplementedError("`tri` is not supported with openvino backend")
1497+
if M is None:
1498+
M = N
1499+
if dtype is None:
1500+
dtype = "float32"
1501+
1502+
ov_dtype = OPENVINO_DTYPES[dtype]
1503+
1504+
def ensure_constant(value, default_type=Type.i32):
1505+
if isinstance(value, (int, float)):
1506+
return ov_opset.constant(value, default_type)
1507+
elif hasattr(value, "get_element_type"):
1508+
if value.get_element_type() != Type.i32:
1509+
value = ov_opset.convert(value, Type.i32)
1510+
return ov_opset.squeeze(value, ov_opset.constant([0], Type.i32))
1511+
else:
1512+
return ov_opset.constant(value, default_type)
1513+
1514+
N_const = ensure_constant(N)
1515+
M_const = ensure_constant(M)
1516+
k_const = ensure_constant(k)
1517+
1518+
# Create row and column indices
1519+
row_range = ov_opset.range(
1520+
ov_opset.constant(0, Type.i32),
1521+
N_const,
1522+
ov_opset.constant(1, Type.i32),
1523+
output_type=Type.i32,
1524+
)
1525+
col_range = ov_opset.range(
1526+
ov_opset.constant(0, Type.i32),
1527+
M_const,
1528+
ov_opset.constant(1, Type.i32),
1529+
output_type=Type.i32,
1530+
)
1531+
1532+
# Reshape indices for broadcasting
1533+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1534+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1535+
1536+
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))
1537+
1538+
if ov_dtype == Type.boolean:
1539+
result = mask
1540+
else:
1541+
result = ov_opset.convert(mask, ov_dtype)
1542+
1543+
return OpenVINOKerasTensor(result.output(0))
14981544

14991545

15001546
def tril(x, k=0):
1501-
raise NotImplementedError("`tril` is not supported with openvino backend")
1547+
x = get_ov_output(x)
1548+
ov_type = x.get_element_type()
1549+
shape = ov_opset.shape_of(x, Type.i32)
1550+
zero_const = ov_opset.constant(0, Type.i32)
1551+
minus2 = ov_opset.constant([-2], Type.i32)
1552+
minus1 = ov_opset.constant([-1], Type.i32)
1553+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1554+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1555+
tri_mask = tri(M, N, k=k, dtype="bool").output
1556+
mask = ov_opset.convert(tri_mask, ov_type)
1557+
if ov_type == Type.boolean:
1558+
out = ov_opset.logical_and(x, mask)
1559+
else:
1560+
out = ov_opset.multiply(x, mask)
1561+
return OpenVINOKerasTensor(out.output(0))
15021562

15031563

15041564
def triu(x, k=0):
1505-
raise NotImplementedError("`triu` is not supported with openvino backend")
1565+
x = get_ov_output(x)
1566+
ov_type = x.get_element_type()
1567+
shape = ov_opset.shape_of(x, Type.i32)
1568+
zero_const = ov_opset.constant(0, Type.i32)
1569+
minus2 = ov_opset.constant([-2], Type.i32)
1570+
minus1 = ov_opset.constant([-1], Type.i32)
1571+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1572+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1573+
tri_mask = tri(M, N, k=k - 1, dtype="bool").output
1574+
if ov_type == Type.boolean:
1575+
mask = ov_opset.logical_not(tri_mask)
1576+
else:
1577+
const_one = ov_opset.constant(1, ov_type)
1578+
converted_mask = ov_opset.convert(tri_mask, ov_type)
1579+
mask = ov_opset.subtract(const_one, converted_mask)
1580+
if ov_type == Type.boolean:
1581+
out = ov_opset.logical_and(x, mask)
1582+
else:
1583+
out = ov_opset.multiply(x, mask)
1584+
return OpenVINOKerasTensor(out.output(0))
15061585

15071586

15081587
def vdot(x1, x2):

0 commit comments

Comments
 (0)