Skip to content

Commit 3ff9ca3

Browse files
[OpenVINO backend] support repeat
1 parent be9b002 commit 3ff9ca3

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ NumpyDtypeTest::test_multiply
4545
NumpyDtypeTest::test_power
4646
NumpyDtypeTest::test_prod
4747
NumpyDtypeTest::test_quantile
48-
NumpyDtypeTest::test_repeat
4948
NumpyDtypeTest::test_roll
5049
NumpyDtypeTest::test_round
5150
NumpyDtypeTest::test_searchsorted
@@ -106,7 +105,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
106105
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
107106
NumpyOneInputOpsCorrectnessTest::test_prod
108107
NumpyOneInputOpsCorrectnessTest::test_real
109-
NumpyOneInputOpsCorrectnessTest::test_repeat
110108
NumpyOneInputOpsCorrectnessTest::test_reshape
111109
NumpyOneInputOpsCorrectnessTest::test_roll
112110
NumpyOneInputOpsCorrectnessTest::test_round

keras/src/backend/openvino/numpy.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,8 +1291,48 @@ def reciprocal(x):
12911291

12921292

12931293
def repeat(x, repeats, axis=None):
1294-
raise NotImplementedError("`repeat` is not supported with openvino backend")
1294+
x = get_ov_output(x)
1295+
1296+
if axis is not None and axis < 0:
1297+
axis += len(x.get_partial_shape())
12951298

1299+
if axis is None:
1300+
x = ov_opset.reshape(x, ov_opset.constant([-1], Type.i32), special_zero=False).output(0)
1301+
axis = 0
1302+
1303+
if isinstance(repeats, (int, np.integer)) or (
1304+
isinstance(repeats, np.ndarray) and repeats.ndim == 1 and repeats.size == 1
1305+
):
1306+
repeats_val = int(repeats) if isinstance(repeats, np.ndarray) else repeats
1307+
input_shape = ov_opset.shape_of(x, Type.i32).output(0)
1308+
dim_len = ov_opset.gather(input_shape, ov_opset.constant([axis], Type.i32), ov_opset.constant(0, Type.i32)).output(0)
1309+
dim_len = ov_opset.squeeze(dim_len, ov_opset.constant([0], Type.i32)).output(0)
1310+
idx_range = ov_opset.range(
1311+
ov_opset.constant(0, Type.i32),
1312+
dim_len,
1313+
ov_opset.constant(1, Type.i32),
1314+
output_type=Type.i32,
1315+
).output(0)
1316+
idx_range = ov_opset.unsqueeze(idx_range, ov_opset.constant([1], Type.i32)).output(0)
1317+
tiled = ov_opset.tile(idx_range, ov_opset.constant([1, repeats_val], Type.i32)).output(0)
1318+
idx = ov_opset.reshape(tiled, ov_opset.constant([-1], Type.i32), special_zero=False).output(0)
1319+
result = ov_opset.gather(x, idx, ov_opset.constant(axis, Type.i32)).output(0)
1320+
return OpenVINOKerasTensor(result)
1321+
1322+
repeats_np = np.array(repeats)
1323+
if repeats_np.ndim != 1:
1324+
raise NotImplementedError("Only 1D repeats arrays are supported.")
1325+
input_shape = ov_opset.shape_of(x, Type.i32).output(0)
1326+
axis_len = ov_opset.gather(input_shape, ov_opset.constant([axis], Type.i32), ov_opset.constant(0, Type.i32)).output(0)
1327+
axis_len_val = int(axis_len.get_vector()[0]) if hasattr(axis_len, "get_vector") else x.get_partial_shape()[axis]
1328+
if axis_len_val != len(repeats_np):
1329+
raise ValueError("repeats length does not match axis length")
1330+
1331+
gather_indices = np.concatenate([np.full(r, i, dtype=np.int32) for i, r in enumerate(repeats_np) if r > 0])
1332+
gather_indices_ov = ov_opset.constant(gather_indices, Type.i32).output(0)
1333+
result = ov_opset.gather(x, gather_indices_ov, ov_opset.constant(axis, Type.i32)).output(0)
1334+
return OpenVINOKerasTensor(result)
1335+
12961336

12971337
def reshape(x, newshape):
12981338
x = get_ov_output(x)

0 commit comments

Comments
 (0)