Skip to content

Commit aefa32f

Browse files
[OpenVINO backend] support repeat
1 parent be9b002 commit aefa32f

File tree

2 files changed

+90
-3
lines changed

2 files changed

+90
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ NumpyDtypeTest::test_multiply
4545
NumpyDtypeTest::test_power
4646
NumpyDtypeTest::test_prod
4747
NumpyDtypeTest::test_quantile
48-
NumpyDtypeTest::test_repeat
4948
NumpyDtypeTest::test_roll
5049
NumpyDtypeTest::test_round
5150
NumpyDtypeTest::test_searchsorted
@@ -106,7 +105,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
106105
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
107106
NumpyOneInputOpsCorrectnessTest::test_prod
108107
NumpyOneInputOpsCorrectnessTest::test_real
109-
NumpyOneInputOpsCorrectnessTest::test_repeat
110108
NumpyOneInputOpsCorrectnessTest::test_reshape
111109
NumpyOneInputOpsCorrectnessTest::test_roll
112110
NumpyOneInputOpsCorrectnessTest::test_round

keras/src/backend/openvino/numpy.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,96 @@ def reciprocal(x):
12911291

12921292

12931293
def repeat(x, repeats, axis=None):
    """Repeat the elements of `x` along `axis` by gathering built indices.

    Two strategies are used depending on `repeats`:

    * Uniform count (Python/NumPy integer, or a size-1 NumPy array): the
      index vector `[0..dim)` is tiled `repeats` times per element.
    * Per-element counts (anything else): indices are derived from the
      cumulative sum of `repeats` (a searchsorted expressed with
      compare + reduce ops).

    Args:
        x: Input value; converted to an OpenVINO output with
            `get_ov_output`.
        repeats: Number of repetitions. A Python `int`, a NumPy integer
            or a NumPy array holding exactly one element (0-D or 1-D) is
            applied uniformly. Any other value is converted with
            `get_ov_output` and reduced along its axis 0 -- presumably a
            1-D integer tensor with one count per element along `axis`
            (not validated here).
        axis: Axis along which to repeat. Negative values are offset by
            the rank of `x`. If `None`, `x` is flattened first and the
            repeat happens along axis 0.

    Returns:
        An `OpenVINOKerasTensor` wrapping the gathered result.
    """
    x = get_ov_output(x)

    if axis is None:
        # No axis given: operate on the flattened input.
        x = ov_opset.reshape(
            x, ov_opset.constant([-1], Type.i32), special_zero=False
        ).output(0)
        axis = 0
    elif axis < 0:
        axis += len(x.get_partial_shape())

    is_uniform = isinstance(repeats, (int, np.integer)) or (
        isinstance(repeats, np.ndarray)
        and repeats.ndim <= 1
        and repeats.size == 1
    )
    if is_uniform:
        # `.item()` extracts the single element for both 0-D and 1-D
        # arrays; `int(array)` on an ndim > 0 array is deprecated in NumPy.
        if isinstance(repeats, np.ndarray):
            repeats_val = int(repeats.item())
        else:
            repeats_val = int(repeats)

        # Length of the repeated axis, as a scalar i32.
        input_shape = ov_opset.shape_of(x, Type.i32).output(0)
        dim_len = ov_opset.gather(
            input_shape,
            ov_opset.constant([axis], Type.i32),
            ov_opset.constant(0, Type.i32),
        ).output(0)
        dim_len = ov_opset.squeeze(
            dim_len, ov_opset.constant([0], Type.i32)
        ).output(0)

        # [0, 1, ..., dim_len - 1] as a column, tiled `repeats_val` times
        # along the second axis, then flattened row-major so that every
        # index appears `repeats_val` times consecutively.
        idx_range = ov_opset.range(
            ov_opset.constant(0, Type.i32),
            dim_len,
            ov_opset.constant(1, Type.i32),
            output_type=Type.i32,
        ).output(0)
        idx_range = ov_opset.unsqueeze(
            idx_range, ov_opset.constant([1], Type.i32)
        ).output(0)
        tiled = ov_opset.tile(
            idx_range, ov_opset.constant([1, repeats_val], Type.i32)
        ).output(0)
        idx = ov_opset.reshape(
            tiled, ov_opset.constant([-1], Type.i32), special_zero=False
        ).output(0)
        result = ov_opset.gather(
            x, idx, ov_opset.constant(axis, Type.i32)
        ).output(0)
        return OpenVINOKerasTensor(result)

    repeats_tensor = get_ov_output(repeats)

    # Cumulative sum of the counts and the total output length.
    cumsum = ov_opset.cumsum(repeats_tensor, ov_opset.constant(0, Type.i32))
    total = ov_opset.reduce_sum(
        repeats_tensor, ov_opset.constant([0], Type.i32), keep_dims=False
    )
    total = ov_opset.convert(total, Type.i32)

    # Output positions [0, 1, ..., total - 1].
    out_indices = ov_opset.range(
        ov_opset.constant(0, Type.i32),
        total,
        ov_opset.constant(1, Type.i32),
        output_type=Type.i32,
    )

    # For each output position, find the interval of `cumsum` it falls in
    # (a searchsorted): source index = sum(position >= cumsum).
    cumsum_unsq = ov_opset.unsqueeze(cumsum, ov_opset.constant([0], Type.i32))
    out_indices_unsq = ov_opset.unsqueeze(
        out_indices, ov_opset.constant([1], Type.i32)
    )
    # Match the i32 type of `out_indices` before comparing.
    cumsum_unsq = ov_opset.convert(cumsum_unsq, Type.i32)
    mask = ov_opset.greater_equal(out_indices_unsq, cumsum_unsq)
    gather_indices = ov_opset.reduce_sum(
        ov_opset.convert(mask, Type.i32),
        ov_opset.constant([1], Type.i32),
        keep_dims=False,
    )

    result = ov_opset.gather(
        x, gather_indices, ov_opset.constant(axis, Type.i32)
    ).output(0)
    return OpenVINOKerasTensor(result)
12951384

12961385

12971386
def reshape(x, newshape):

0 commit comments

Comments
 (0)