Skip to content

Commit 96408fa

Browse files
[OpenVINO backend] support triu [OpenVINO backend] support tri, triu, and tril
1 parent be9b002 commit 96408fa

File tree

2 files changed

+214
-8
lines changed

2 files changed

+214
-8
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
NumPyTestRot90
22
NumpyArrayCreateOpsCorrectnessTest::test_eye
3-
NumpyArrayCreateOpsCorrectnessTest::test_tri
43
NumpyDtypeTest::test_absolute_bool
54
NumpyDtypeTest::test_add_
65
NumpyDtypeTest::test_all
@@ -60,7 +59,6 @@ NumpyDtypeTest::test_take_along_axis
6059
NumpyDtypeTest::test_tensordot_
6160
NumpyDtypeTest::test_tile
6261
NumpyDtypeTest::test_trace
63-
NumpyDtypeTest::test_tri
6462
NumpyDtypeTest::test_trunc
6563
NumpyDtypeTest::test_unravel
6664
NumpyDtypeTest::test_var
@@ -124,9 +122,6 @@ NumpyOneInputOpsCorrectnessTest::test_swapaxes
124122
NumpyOneInputOpsCorrectnessTest::test_tile
125123
NumpyOneInputOpsCorrectnessTest::test_trace
126124
NumpyOneInputOpsCorrectnessTest::test_transpose
127-
NumpyOneInputOpsCorrectnessTest::test_tril
128-
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
129-
NumpyOneInputOpsCorrectnessTest::test_triu
130125
NumpyOneInputOpsCorrectnessTest::test_trunc
131126
NumpyOneInputOpsCorrectnessTest::test_unravel_index
132127
NumpyOneInputOpsCorrectnessTest::test_var

keras/src/backend/openvino/numpy.py

Lines changed: 214 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,15 +1484,226 @@ def trace(x, offset=0, axis1=0, axis2=1):
14841484

14851485

14861486
def tri(N, M=None, k=0, dtype=None):
1487-
raise NotImplementedError("`tri` is not supported with openvino backend")
1487+
if M is None:
1488+
M = N
1489+
if dtype is None:
1490+
dtype = "float32"
1491+
1492+
ov_dtype = OPENVINO_DTYPES[dtype]
1493+
1494+
N = ov_opset.constant(N, Type.i32)
1495+
M = ov_opset.constant(M, Type.i32)
1496+
k = ov_opset.constant(k, Type.i32)
1497+
1498+
row_range = ov_opset.range(
1499+
ov_opset.constant(0, Type.i32),
1500+
N,
1501+
ov_opset.constant(1, Type.i32),
1502+
output_type=Type.i32,
1503+
)
1504+
col_range = ov_opset.range(
1505+
ov_opset.constant(0, Type.i32),
1506+
M,
1507+
ov_opset.constant(1, Type.i32),
1508+
output_type=Type.i32,
1509+
)
1510+
1511+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1512+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1513+
1514+
target_shape = ov_opset.concat(
1515+
[ov_opset.unsqueeze(N, [0]), ov_opset.unsqueeze(M, [0])], axis=0
1516+
)
1517+
1518+
row_idx = ov_opset.broadcast(row_idx, target_shape)
1519+
col_idx = ov_opset.broadcast(col_idx, target_shape)
1520+
1521+
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k))
1522+
1523+
if ov_dtype == Type.boolean:
1524+
result = mask
1525+
else:
1526+
result = ov_opset.convert(mask, ov_dtype)
1527+
1528+
return OpenVINOKerasTensor(result.output(0))
14881529

14891530

14901531
def tril(x, k=0):
1491-
raise NotImplementedError("`tril` is not supported with openvino backend")
1532+
def get_shape_dims(x):
1533+
shape = ov_opset.shape_of(x, Type.i32)
1534+
rank_tensor = ov_opset.shape_of(shape, Type.i32)
1535+
rank_scalar = ov_opset.squeeze(
1536+
rank_tensor, ov_opset.constant([0], Type.i32)
1537+
)
1538+
indices = ov_opset.range(
1539+
ov_opset.constant(0, Type.i32),
1540+
rank_scalar,
1541+
ov_opset.constant(1, Type.i32),
1542+
output_type=Type.i32,
1543+
)
1544+
return ov_opset.gather(shape, indices, axis=0)
1545+
1546+
x = get_ov_output(x)
1547+
ov_type = x.get_element_type()
1548+
input_shape = ov_opset.shape_of(x, Type.i32)
1549+
shape = get_shape_dims(x)
1550+
1551+
zero_const = ov_opset.constant(0, Type.i32)
1552+
minus2 = ov_opset.constant([-2], Type.i32)
1553+
minus1 = ov_opset.constant([-1], Type.i32)
1554+
1555+
M = ov_opset.squeeze(
1556+
ov_opset.gather(shape, minus2, zero_const),
1557+
ov_opset.constant([0], Type.i32),
1558+
)
1559+
N = ov_opset.squeeze(
1560+
ov_opset.gather(shape, minus1, zero_const),
1561+
ov_opset.constant([0], Type.i32),
1562+
)
1563+
1564+
row_range = ov_opset.range(
1565+
ov_opset.constant(0, Type.i32),
1566+
M,
1567+
ov_opset.constant(1, Type.i32),
1568+
output_type=Type.i32,
1569+
)
1570+
col_range = ov_opset.range(
1571+
ov_opset.constant(0, Type.i32),
1572+
N,
1573+
ov_opset.constant(1, Type.i32),
1574+
output_type=Type.i32,
1575+
)
1576+
1577+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1578+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1579+
1580+
M_1d = ov_opset.unsqueeze(M, ov_opset.constant([0], Type.i32))
1581+
N_1d = ov_opset.unsqueeze(N, ov_opset.constant([0], Type.i32))
1582+
target_shape = ov_opset.concat([M_1d, N_1d], axis=0)
1583+
1584+
row_idx = ov_opset.broadcast(row_idx, target_shape)
1585+
col_idx = ov_opset.broadcast(col_idx, target_shape)
1586+
1587+
k_const = ov_opset.constant(k, Type.i32)
1588+
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))
1589+
mask = ov_opset.convert(mask, ov_type)
1590+
1591+
shape_rank_tensor = ov_opset.shape_of(input_shape, Type.i32)
1592+
shape_rank = ov_opset.squeeze(
1593+
shape_rank_tensor, ov_opset.constant([0], Type.i32)
1594+
)
1595+
batch_dims_count = ov_opset.subtract(
1596+
shape_rank, ov_opset.constant(2, Type.i32)
1597+
)
1598+
batch_dims_count = ov_opset.squeeze(
1599+
batch_dims_count, ov_opset.constant([0], Type.i32)
1600+
)
1601+
1602+
batch_indices = ov_opset.range(
1603+
start=ov_opset.constant(0, Type.i32),
1604+
stop=batch_dims_count,
1605+
step=ov_opset.constant(1, Type.i32),
1606+
output_type=Type.i32,
1607+
)
1608+
1609+
batch_shape = ov_opset.gather(input_shape, batch_indices, axis=0)
1610+
full_mask_shape = ov_opset.concat([batch_shape, M_1d, N_1d], axis=0)
1611+
mask = ov_opset.broadcast(mask, full_mask_shape)
1612+
1613+
if ov_type == Type.boolean:
1614+
out = ov_opset.logical_and(x, mask)
1615+
else:
1616+
out = ov_opset.multiply(x, mask)
1617+
return OpenVINOKerasTensor(out.output(0))
14921618

14931619

14941620
def triu(x, k=0):
1495-
raise NotImplementedError("`triu` is not supported with openvino backend")
1621+
def get_shape_dims(x):
1622+
shape = ov_opset.shape_of(x, Type.i32)
1623+
rank_tensor = ov_opset.shape_of(shape, Type.i32)
1624+
rank_scalar = ov_opset.squeeze(
1625+
rank_tensor, ov_opset.constant([0], Type.i32)
1626+
)
1627+
indices = ov_opset.range(
1628+
ov_opset.constant(0, Type.i32),
1629+
rank_scalar,
1630+
ov_opset.constant(1, Type.i32),
1631+
output_type=Type.i32,
1632+
)
1633+
return ov_opset.gather(shape, indices, axis=0)
1634+
1635+
x = get_ov_output(x)
1636+
ov_type = x.get_element_type()
1637+
input_shape = ov_opset.shape_of(x, Type.i32)
1638+
shape = get_shape_dims(x)
1639+
1640+
zero_const = ov_opset.constant(0, Type.i32)
1641+
minus2 = ov_opset.constant([-2], Type.i32)
1642+
minus1 = ov_opset.constant([-1], Type.i32)
1643+
1644+
M = ov_opset.squeeze(
1645+
ov_opset.gather(shape, minus2, zero_const),
1646+
ov_opset.constant([0], Type.i32),
1647+
)
1648+
N = ov_opset.squeeze(
1649+
ov_opset.gather(shape, minus1, zero_const),
1650+
ov_opset.constant([0], Type.i32),
1651+
)
1652+
1653+
row_range = ov_opset.range(
1654+
ov_opset.constant(0, Type.i32),
1655+
M,
1656+
ov_opset.constant(1, Type.i32),
1657+
output_type=Type.i32,
1658+
)
1659+
col_range = ov_opset.range(
1660+
ov_opset.constant(0, Type.i32),
1661+
N,
1662+
ov_opset.constant(1, Type.i32),
1663+
output_type=Type.i32,
1664+
)
1665+
1666+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1667+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1668+
1669+
M_1d = ov_opset.unsqueeze(M, ov_opset.constant([0], Type.i32))
1670+
N_1d = ov_opset.unsqueeze(N, ov_opset.constant([0], Type.i32))
1671+
target_shape = ov_opset.concat([M_1d, N_1d], axis=0)
1672+
1673+
row_idx = ov_opset.broadcast(row_idx, target_shape)
1674+
col_idx = ov_opset.broadcast(col_idx, target_shape)
1675+
1676+
k_const = ov_opset.constant(k, Type.i32)
1677+
mask = ov_opset.greater_equal(col_idx, ov_opset.add(row_idx, k_const))
1678+
mask = ov_opset.convert(mask, ov_type)
1679+
1680+
shape_rank_tensor = ov_opset.shape_of(input_shape, Type.i32)
1681+
shape_rank = ov_opset.squeeze(
1682+
shape_rank_tensor, ov_opset.constant([0], Type.i32)
1683+
)
1684+
batch_dims_count = ov_opset.subtract(
1685+
shape_rank, ov_opset.constant(2, Type.i32)
1686+
)
1687+
batch_dims_count = ov_opset.squeeze(
1688+
batch_dims_count, ov_opset.constant([0], Type.i32)
1689+
)
1690+
1691+
batch_indices = ov_opset.range(
1692+
start=ov_opset.constant(0, Type.i32),
1693+
stop=batch_dims_count,
1694+
step=ov_opset.constant(1, Type.i32),
1695+
output_type=Type.i32,
1696+
)
1697+
1698+
batch_shape = ov_opset.gather(input_shape, batch_indices, axis=0)
1699+
full_mask_shape = ov_opset.concat([batch_shape, M_1d, N_1d], axis=0)
1700+
mask = ov_opset.broadcast(mask, full_mask_shape)
1701+
1702+
if ov_type == Type.boolean:
1703+
out = ov_opset.logical_and(x, mask)
1704+
else:
1705+
out = ov_opset.multiply(x, mask)
1706+
return OpenVINOKerasTensor(out.output(0))
14961707

14971708

14981709
def vdot(x1, x2):

0 commit comments

Comments
 (0)