Skip to content

Commit 3bbbf99

Browse files
[OpenVINO backend] support triu
1 parent be9b002 commit 3bbbf99

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ NumpyOneInputOpsCorrectnessTest::test_trace
126126
NumpyOneInputOpsCorrectnessTest::test_transpose
127127
NumpyOneInputOpsCorrectnessTest::test_tril
128128
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
129-
NumpyOneInputOpsCorrectnessTest::test_triu
130129
NumpyOneInputOpsCorrectnessTest::test_trunc
131130
NumpyOneInputOpsCorrectnessTest::test_unravel_index
132131
NumpyOneInputOpsCorrectnessTest::test_var

keras/src/backend/openvino/numpy.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1492,7 +1492,90 @@ def tril(x, k=0):
14921492

14931493

14941494
def triu(x, k=0):
1495-
raise NotImplementedError("`triu` is not supported with openvino backend")
1495+
def get_shape_dims(x):
1496+
shape = ov_opset.shape_of(x, Type.i32)
1497+
rank_tensor = ov_opset.shape_of(shape, Type.i32)
1498+
rank_scalar = ov_opset.squeeze(
1499+
rank_tensor, ov_opset.constant([0], Type.i32)
1500+
)
1501+
indices = ov_opset.range(
1502+
ov_opset.constant(0, Type.i32),
1503+
rank_scalar,
1504+
ov_opset.constant(1, Type.i32),
1505+
output_type=Type.i32,
1506+
)
1507+
return ov_opset.gather(shape, indices, axis=0)
1508+
1509+
x = get_ov_output(x)
1510+
ov_type = x.get_element_type()
1511+
input_shape = ov_opset.shape_of(x, Type.i32)
1512+
shape = get_shape_dims(x)
1513+
1514+
zero_const = ov_opset.constant(0, Type.i32)
1515+
minus2 = ov_opset.constant([-2], Type.i32)
1516+
minus1 = ov_opset.constant([-1], Type.i32)
1517+
1518+
M = ov_opset.squeeze(
1519+
ov_opset.gather(shape, minus2, zero_const),
1520+
ov_opset.constant([0], Type.i32),
1521+
)
1522+
N = ov_opset.squeeze(
1523+
ov_opset.gather(shape, minus1, zero_const),
1524+
ov_opset.constant([0], Type.i32),
1525+
)
1526+
1527+
row_range = ov_opset.range(
1528+
ov_opset.constant(0, Type.i32),
1529+
M,
1530+
ov_opset.constant(1, Type.i32),
1531+
output_type=Type.i32,
1532+
)
1533+
col_range = ov_opset.range(
1534+
ov_opset.constant(0, Type.i32),
1535+
N,
1536+
ov_opset.constant(1, Type.i32),
1537+
output_type=Type.i32,
1538+
)
1539+
1540+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1541+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1542+
1543+
M_1d = ov_opset.unsqueeze(M, ov_opset.constant([0], Type.i32))
1544+
N_1d = ov_opset.unsqueeze(N, ov_opset.constant([0], Type.i32))
1545+
target_shape = ov_opset.concat([M_1d, N_1d], axis=0)
1546+
1547+
row_idx = ov_opset.broadcast(row_idx, target_shape)
1548+
col_idx = ov_opset.broadcast(col_idx, target_shape)
1549+
1550+
k_const = ov_opset.constant(k, Type.i32)
1551+
mask = ov_opset.greater_equal(col_idx, ov_opset.add(row_idx, k_const))
1552+
mask = ov_opset.convert(mask, ov_type)
1553+
1554+
shape_rank_tensor = ov_opset.shape_of(input_shape, Type.i32)
1555+
shape_rank = ov_opset.squeeze(
1556+
shape_rank_tensor, ov_opset.constant([0], Type.i32)
1557+
)
1558+
batch_dims_count = ov_opset.subtract(
1559+
shape_rank, ov_opset.constant(2, Type.i32)
1560+
)
1561+
batch_dims_count = ov_opset.squeeze(
1562+
batch_dims_count, ov_opset.constant([0], Type.i32)
1563+
)
1564+
1565+
batch_indices = ov_opset.range(
1566+
start=ov_opset.constant(0, Type.i32),
1567+
stop=batch_dims_count,
1568+
step=ov_opset.constant(1, Type.i32),
1569+
output_type=Type.i32,
1570+
)
1571+
1572+
batch_shape = ov_opset.gather(input_shape, batch_indices, axis=0)
1573+
full_mask_shape = ov_opset.concat([batch_shape, M_1d, N_1d], axis=0)
1574+
mask = ov_opset.broadcast(mask, full_mask_shape)
1575+
if ov_type == Type.boolean:
1576+
mask = ov_opset.convert(mask, Type.f32)
1577+
x = ov_opset.convert(x, Type.f32)
1578+
return OpenVINOKerasTensor(ov_opset.multiply(x, mask).output(0))
14961579

14971580

14981581
def vdot(x1, x2):

0 commit comments

Comments
 (0)