Skip to content

Commit 8dffc46

Browse files
[OpenVINO backend] support repeat
1 parent be9b002 commit 8dffc46

File tree

2 files changed

+77
-3
lines changed

2 files changed

+77
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ NumpyDtypeTest::test_multiply
4545
NumpyDtypeTest::test_power
4646
NumpyDtypeTest::test_prod
4747
NumpyDtypeTest::test_quantile
48-
NumpyDtypeTest::test_repeat
4948
NumpyDtypeTest::test_roll
5049
NumpyDtypeTest::test_round
5150
NumpyDtypeTest::test_searchsorted
@@ -106,7 +105,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
106105
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
107106
NumpyOneInputOpsCorrectnessTest::test_prod
108107
NumpyOneInputOpsCorrectnessTest::test_real
109-
NumpyOneInputOpsCorrectnessTest::test_repeat
110108
NumpyOneInputOpsCorrectnessTest::test_reshape
111109
NumpyOneInputOpsCorrectnessTest::test_roll
112110
NumpyOneInputOpsCorrectnessTest::test_round

keras/src/backend/openvino/numpy.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,83 @@ def reciprocal(x):
12911291

12921292

12931293
def repeat(x, repeats, axis=None):
1294-
raise NotImplementedError("`repeat` is not supported with openvino backend")
1294+
x = get_ov_output(x)
1295+
1296+
if axis is not None and axis < 0:
1297+
axis += len(x.get_partial_shape())
1298+
1299+
if axis is None:
1300+
x = ov_opset.reshape(
1301+
x, ov_opset.constant([-1], Type.i32), special_zero=False
1302+
).output(0)
1303+
axis = 0
1304+
1305+
if isinstance(repeats, (int, np.integer)) or (
1306+
isinstance(repeats, np.ndarray)
1307+
and repeats.ndim == 1
1308+
and repeats.size == 1
1309+
):
1310+
repeats_val = (
1311+
int(repeats) if isinstance(repeats, np.ndarray) else repeats
1312+
)
1313+
input_shape = ov_opset.shape_of(x, Type.i32).output(0)
1314+
dim_len = ov_opset.gather(
1315+
input_shape,
1316+
ov_opset.constant([axis], Type.i32),
1317+
ov_opset.constant(0, Type.i32),
1318+
).output(0)
1319+
dim_len = ov_opset.squeeze(
1320+
dim_len, ov_opset.constant([0], Type.i32)
1321+
).output(0)
1322+
idx_range = ov_opset.range(
1323+
ov_opset.constant(0, Type.i32),
1324+
dim_len,
1325+
ov_opset.constant(1, Type.i32),
1326+
output_type=Type.i32,
1327+
).output(0)
1328+
idx_range = ov_opset.unsqueeze(
1329+
idx_range, ov_opset.constant([1], Type.i32)
1330+
).output(0)
1331+
tiled = ov_opset.tile(
1332+
idx_range, ov_opset.constant([1, repeats_val], Type.i32)
1333+
).output(0)
1334+
idx = ov_opset.reshape(
1335+
tiled, ov_opset.constant([-1], Type.i32), special_zero=False
1336+
).output(0)
1337+
result = ov_opset.gather(
1338+
x, idx, ov_opset.constant(axis, Type.i32)
1339+
).output(0)
1340+
return OpenVINOKerasTensor(result)
1341+
1342+
repeats_np = np.array(repeats)
1343+
if repeats_np.ndim != 1:
1344+
raise NotImplementedError("Only 1D repeats arrays are supported.")
1345+
input_shape = ov_opset.shape_of(x, Type.i32).output(0)
1346+
axis_len = ov_opset.gather(
1347+
input_shape,
1348+
ov_opset.constant([axis], Type.i32),
1349+
ov_opset.constant(0, Type.i32),
1350+
).output(0)
1351+
axis_len_val = (
1352+
int(axis_len.get_vector()[0])
1353+
if hasattr(axis_len, "get_vector")
1354+
else x.get_partial_shape()[axis]
1355+
)
1356+
if axis_len_val != len(repeats_np):
1357+
raise ValueError("repeats length does not match axis length")
1358+
1359+
gather_indices = np.concatenate(
1360+
[
1361+
np.full(r, i, dtype=np.int32)
1362+
for i, r in enumerate(repeats_np)
1363+
if r > 0
1364+
]
1365+
)
1366+
gather_indices_ov = ov_opset.constant(gather_indices, Type.i32).output(0)
1367+
result = ov_opset.gather(
1368+
x, gather_indices_ov, ov_opset.constant(axis, Type.i32)
1369+
).output(0)
1370+
return OpenVINOKerasTensor(result)
12951371

12961372

12971373
def reshape(x, newshape):

0 commit comments

Comments
 (0)