@@ -1442,10 +1442,75 @@ def take(x, indices, axis=None):
1442
1442
1443
1443
1444
1444
def take_along_axis (x , indices , axis = None ):
1445
- raise NotImplementedError (
1446
- "`take_along_axis` is not supported with openvino backend"
1445
+ x = get_ov_output (x )
1446
+ indices = get_ov_output (indices )
1447
+
1448
+ if axis is None :
1449
+ target_shape = ov_opset .constant ([- 1 ], dtype = Type .i32 ).output (0 )
1450
+ x_flat = ov_opset .reshape (x , target_shape , False ).output (0 )
1451
+ indices_flat = ov_opset .reshape (indices , target_shape , False ).output (0 )
1452
+ result = ov_opset .gather_elements (x_flat , indices_flat , 0 ).output (0 )
1453
+ return OpenVINOKerasTensor (result )
1454
+
1455
+ x_rank = len (x .get_partial_shape ())
1456
+ if axis < 0 :
1457
+ axis += x_rank
1458
+
1459
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1460
+ indices_shape = ov_opset .shape_of (indices , Type .i32 ).output (0 )
1461
+
1462
+ zero_const = ov_opset .constant (0 , dtype = Type .i32 ).output (0 )
1463
+ axis_index = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1464
+
1465
+ # Fix negative indices
1466
+ dim_size = ov_opset .squeeze (
1467
+ ov_opset .gather (x_shape , axis_index , zero_const ).output (0 ), zero_const
1468
+ ).output (0 )
1469
+ zero_scalar = ov_opset .constant (0 , indices .get_element_type ()).output (0 )
1470
+ is_neg = ov_opset .less (indices , zero_scalar ).output (0 )
1471
+ dim_size_cast = ov_opset .convert (
1472
+ dim_size , indices .get_element_type ()
1473
+ ).output (0 )
1474
+ indices = ov_opset .select (
1475
+ is_neg , ov_opset .add (indices , dim_size_cast ).output (0 ), indices
1476
+ ).output (0 )
1477
+ indices = ov_opset .convert (indices , Type .i32 ).output (0 )
1478
+
1479
+ x_target_parts , indices_target_parts = [], []
1480
+
1481
+ for i in range (x_rank ):
1482
+ dim_idx = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1483
+ x_dim = ov_opset .gather (x_shape , dim_idx , zero_const ).output (0 )
1484
+ indices_dim = ov_opset .gather (
1485
+ indices_shape , dim_idx , zero_const
1486
+ ).output (0 )
1487
+
1488
+ if i == axis :
1489
+ # For axis dimension: keep original dimensions
1490
+ x_target_parts .append (x_dim )
1491
+ indices_target_parts .append (indices_dim )
1492
+ else :
1493
+ # For other dimensions: use maximum for broadcasting
1494
+ max_dim = ov_opset .maximum (x_dim , indices_dim ).output (0 )
1495
+ x_target_parts .append (max_dim )
1496
+ indices_target_parts .append (max_dim )
1497
+
1498
+ x_target_shape = ov_opset .concat (x_target_parts , axis = 0 ).output (0 )
1499
+ indices_target_shape = ov_opset .concat (indices_target_parts , axis = 0 ).output (
1500
+ 0
1447
1501
)
1448
1502
1503
+ # Broadcast to target shapes and gather elements
1504
+ x_broadcasted = ov_opset .broadcast (x , x_target_shape ).output (0 )
1505
+ indices_broadcasted = ov_opset .broadcast (
1506
+ indices , indices_target_shape
1507
+ ).output (0 )
1508
+ result = ov_opset .gather_elements (
1509
+ x_broadcasted , indices_broadcasted , axis
1510
+ ).output (0 )
1511
+
1512
+ return OpenVINOKerasTensor (result )
1513
+
1449
1514
1450
1515
def tan (x ):
1451
1516
x = get_ov_output (x )
0 commit comments