Skip to content

Commit 2909693

Browse files
[OpenVINO backend] support take_along_axis
1 parent be9b002 commit 2909693

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ NumpyDtypeTest::test_std
5656
NumpyDtypeTest::test_subtract
5757
NumpyDtypeTest::test_sum
5858
NumpyDtypeTest::test_swapaxes
59-
NumpyDtypeTest::test_take_along_axis
6059
NumpyDtypeTest::test_tensordot_
6160
NumpyDtypeTest::test_tile
6261
NumpyDtypeTest::test_trace
@@ -145,7 +144,6 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
145144
NumpyTwoInputOpsCorrectnessTest::test_linspace
146145
NumpyTwoInputOpsCorrectnessTest::test_logspace
147146
NumpyTwoInputOpsCorrectnessTest::test_quantile
148-
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
149147
NumpyTwoInputOpsCorrectnessTest::test_tensordot
150148
NumpyTwoInputOpsCorrectnessTest::test_vdot
151149
NumpyOneInputOpsDynamicShapeTest::test_angle

keras/src/backend/openvino/numpy.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,10 +1442,75 @@ def take(x, indices, axis=None):
14421442

14431443

14441444
def take_along_axis(x, indices, axis=None):
1445-
raise NotImplementedError(
1446-
"`take_along_axis` is not supported with openvino backend"
1445+
x = get_ov_output(x)
1446+
indices = get_ov_output(indices)
1447+
1448+
if axis is None:
1449+
target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)
1450+
x_flat = ov_opset.reshape(x, target_shape, False).output(0)
1451+
indices_flat = ov_opset.reshape(indices, target_shape, False).output(0)
1452+
result = ov_opset.gather_elements(x_flat, indices_flat, 0).output(0)
1453+
return OpenVINOKerasTensor(result)
1454+
1455+
x_rank = len(x.get_partial_shape())
1456+
if axis < 0:
1457+
axis += x_rank
1458+
1459+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1460+
indices_shape = ov_opset.shape_of(indices, Type.i32).output(0)
1461+
1462+
zero_const = ov_opset.constant(0, dtype=Type.i32).output(0)
1463+
axis_index = ov_opset.constant([axis], dtype=Type.i32).output(0)
1464+
1465+
# Fix negative indices
1466+
dim_size = ov_opset.squeeze(
1467+
ov_opset.gather(x_shape, axis_index, zero_const).output(0), zero_const
1468+
).output(0)
1469+
zero_scalar = ov_opset.constant(0, indices.get_element_type()).output(0)
1470+
is_neg = ov_opset.less(indices, zero_scalar).output(0)
1471+
dim_size_cast = ov_opset.convert(
1472+
dim_size, indices.get_element_type()
1473+
).output(0)
1474+
indices = ov_opset.select(
1475+
is_neg, ov_opset.add(indices, dim_size_cast).output(0), indices
1476+
).output(0)
1477+
indices = ov_opset.convert(indices, Type.i32).output(0)
1478+
1479+
x_target_parts, indices_target_parts = [], []
1480+
1481+
for i in range(x_rank):
1482+
dim_idx = ov_opset.constant([i], dtype=Type.i32).output(0)
1483+
x_dim = ov_opset.gather(x_shape, dim_idx, zero_const).output(0)
1484+
indices_dim = ov_opset.gather(
1485+
indices_shape, dim_idx, zero_const
1486+
).output(0)
1487+
1488+
if i == axis:
1489+
# For axis dimension: keep original dimensions
1490+
x_target_parts.append(x_dim)
1491+
indices_target_parts.append(indices_dim)
1492+
else:
1493+
# For other dimensions: use maximum for broadcasting
1494+
max_dim = ov_opset.maximum(x_dim, indices_dim).output(0)
1495+
x_target_parts.append(max_dim)
1496+
indices_target_parts.append(max_dim)
1497+
1498+
x_target_shape = ov_opset.concat(x_target_parts, axis=0).output(0)
1499+
indices_target_shape = ov_opset.concat(indices_target_parts, axis=0).output(
1500+
0
14471501
)
14481502

1503+
# Broadcast to target shapes and gather elements
1504+
x_broadcasted = ov_opset.broadcast(x, x_target_shape).output(0)
1505+
indices_broadcasted = ov_opset.broadcast(
1506+
indices, indices_target_shape
1507+
).output(0)
1508+
result = ov_opset.gather_elements(
1509+
x_broadcasted, indices_broadcasted, axis
1510+
).output(0)
1511+
1512+
return OpenVINOKerasTensor(result)
1513+
14491514

14501515
def tan(x):
14511516
x = get_ov_output(x)

0 commit comments

Comments
 (0)