@@ -1291,7 +1291,77 @@ def reciprocal(x):
1291
1291
1292
1292
1293
1293
def repeat (x , repeats , axis = None ):
1294
- raise NotImplementedError ("`repeat` is not supported with openvino backend" )
1294
+ x = get_ov_output (x )
1295
+
1296
+ if axis is not None and axis < 0 :
1297
+ axis += len (x .get_partial_shape ())
1298
+
1299
+ if axis is None :
1300
+ x = ov_opset .reshape (
1301
+ x , ov_opset .constant ([- 1 ], Type .i32 ), special_zero = False
1302
+ ).output (0 )
1303
+ axis = 0
1304
+
1305
+ if isinstance (repeats , (int , np .integer )) or (
1306
+ isinstance (repeats , np .ndarray )
1307
+ and repeats .ndim == 1
1308
+ and repeats .size == 1
1309
+ ):
1310
+ repeats_val = (
1311
+ int (repeats ) if isinstance (repeats , np .ndarray ) else repeats
1312
+ )
1313
+ input_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1314
+ dim_len = ov_opset .gather (
1315
+ input_shape ,
1316
+ ov_opset .constant ([axis ], Type .i32 ),
1317
+ ov_opset .constant (0 , Type .i32 ),
1318
+ ).output (0 )
1319
+ dim_len = ov_opset .squeeze (
1320
+ dim_len , ov_opset .constant ([0 ], Type .i32 )
1321
+ ).output (0 )
1322
+ idx_range = ov_opset .range (
1323
+ ov_opset .constant (0 , Type .i32 ),
1324
+ dim_len ,
1325
+ ov_opset .constant (1 , Type .i32 ),
1326
+ output_type = Type .i32 ,
1327
+ ).output (0 )
1328
+ idx_range = ov_opset .unsqueeze (
1329
+ idx_range , ov_opset .constant ([1 ], Type .i32 )
1330
+ ).output (0 )
1331
+ tiled = ov_opset .tile (
1332
+ idx_range , ov_opset .constant ([1 , repeats_val ], Type .i32 )
1333
+ ).output (0 )
1334
+ idx = ov_opset .reshape (
1335
+ tiled , ov_opset .constant ([- 1 ], Type .i32 ), special_zero = False
1336
+ ).output (0 )
1337
+ result = ov_opset .gather (
1338
+ x , idx , ov_opset .constant (axis , Type .i32 )
1339
+ ).output (0 )
1340
+ return OpenVINOKerasTensor (result )
1341
+
1342
+ repeats_np = np .array (repeats )
1343
+ input_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1344
+ axis_len_val = x .get_partial_shape ()[axis ]
1345
+
1346
+ # Only check if shape is static
1347
+ axis_len_static = x .get_partial_shape ()[axis ]
1348
+ if axis_len_static .is_static and axis_len_static .get_length () != len (
1349
+ repeats_np
1350
+ ):
1351
+ raise ValueError ("repeats length does not match axis length" )
1352
+
1353
+ gather_indices = np .concatenate (
1354
+ [
1355
+ np .full (r , i , dtype = np .int32 )
1356
+ for i , r in enumerate (repeats_np )
1357
+ if r > 0
1358
+ ]
1359
+ )
1360
+ gather_indices_ov = ov_opset .constant (gather_indices , Type .i32 ).output (0 )
1361
+ result = ov_opset .gather (
1362
+ x , gather_indices_ov , ov_opset .constant (axis , Type .i32 )
1363
+ ).output (0 )
1364
+ return OpenVINOKerasTensor (result )
1295
1365
1296
1366
1297
1367
def reshape (x , newshape ):
0 commit comments