Commit ce95589

Make take_along_axis with TF backend compilable. (#21239)
When there are dynamic dimensions, like typically the batch size, `tf.broadcast_dynamic_shape` is not always compilable. Replace it with an ad-hoc implementation for dynamic dimensions, where we rely on the broadcast itself to fail when the shapes are not broadcastable. Tested with https://github.com/keras-team/keras-rs/blob/main/examples/listwise_ranking.py on GPU, as I was not able to distill a simple reproduction of this.
1 parent 4c6b7e3 commit ce95589
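
The heart of the change is the fallback used when the broadcast shape contains dynamic dimensions. Below is a minimal, hypothetical sketch of the idea (not the Keras helper itself, which additionally masks out the `axis` dimension and uses `operation_utils.broadcast_shapes` for the static part), assuming two tensors of equal rank:

    import tensorflow as tf

    def broadcast_full_shape(a, b):
        # Hypothetical helper (not the Keras code): compute the broadcast
        # shape of two equal-rank tensors without tf.broadcast_dynamic_shape.
        a_static = a.shape.as_list()  # may contain None for dynamic dims
        b_static = b.shape.as_list()
        a_dynamic = tf.shape(a, out_type=tf.int64)
        b_dynamic = tf.shape(b, out_type=tf.int64)
        out = []
        for i in range(len(a_static)):
            if a_static[i] is None or b_static[i] is None:
                # `maximum` is only correct if the dims are broadcastable
                # (equal, or one of them is 1); we rely on the later
                # broadcast failing otherwise, instead of checking here.
                out.append(tf.maximum(a_dynamic[i], b_dynamic[i]))
            else:
                # Both dims static: ordinary NumPy-style broadcast rule.
                out.append(max(a_static[i], b_static[i]))
        return out

A list mixing Python ints and scalar int tensors is accepted by `tf.broadcast_to`, so the result can be used as-is, e.g. `x = tf.broadcast_to(x, broadcast_full_shape(x, indices))`.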

1 file changed: keras/src/backend/tensorflow/numpy.py (34 additions, 18 deletions)
@@ -2260,9 +2260,7 @@ def fix_negative_indices(i):
 
 
 def take_along_axis(x, indices, axis=None):
-    from keras.src.ops.operation_utils import (
-        compute_take_along_axis_output_shape,
-    )
+    from keras.src.ops import operation_utils
 
     x = convert_to_tensor(x)
     indices = convert_to_tensor(indices, "int64")
@@ -2276,28 +2274,46 @@ def take_along_axis(x, indices, axis=None):
 
     # Compute the static output shape as later on, all shapes manipulations
     # use dynamic shapes.
-    static_output_shape = compute_take_along_axis_output_shape(
+    static_output_shape = operation_utils.compute_take_along_axis_output_shape(
         x.shape, indices.shape, axis
     )
     rank = x.ndim
     static_axis = axis
     axis = axis + rank if axis < 0 else axis
 
-    # Broadcast shapes to match, ensure that the axis of interest is not
-    # broadcast.
-    x_shape_original = tf.shape(x, out_type=indices.dtype)
-    indices_shape_original = tf.shape(indices, out_type=indices.dtype)
-    x_shape = tf.tensor_scatter_nd_update(x_shape_original, [[axis]], [1])
-    indices_shape = tf.tensor_scatter_nd_update(
-        indices_shape_original, [[axis]], [1]
-    )
-    broadcasted_shape = tf.broadcast_dynamic_shape(x_shape, indices_shape)
-    x_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [x_shape_original[axis]]
-    )
-    indices_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [indices_shape_original[axis]]
+    if axis >= rank:
+        raise ValueError(f"Invalid axis: {static_axis} for input rank: {rank}")
+
+    x_original_shape = shape_op(x)
+    indices_original_shape = shape_op(indices)
+
+    # Broadcast the static shapes first, but not for the `axis` dimension.
+    x_static_shape = list(x.shape)
+    indices_static_shape = list(indices.shape)
+    x_static_shape[axis] = 1
+    indices_static_shape[axis] = 1
+    broadcast_shape = operation_utils.broadcast_shapes(
+        x_static_shape, indices_static_shape
     )
+
+    if None in broadcast_shape:
+        # Dynamic broadcast case. Note that `tf.broadcast_dynamic_shape` is
+        # not always XLA compilable with dynamic dimensions.
+        # We replace `None`s with the dynamic dimensions.
+        # `maximum` is the correct formula only when shapes are broadcastable;
+        # we rely on the broadcast itself to fail in the incorrect case rather
+        # than make some expensive dynamic checks here.
+        broadcast_shape = [
+            tf.maximum(x_original_shape[i], indices_original_shape[i])
+            if dim is None
+            else dim
+            for i, dim in enumerate(broadcast_shape)
+        ]
+
+    x_shape = list(broadcast_shape)
+    x_shape[axis] = x_original_shape[axis]
+    indices_shape = list(broadcast_shape)
+    indices_shape[axis] = indices_original_shape[axis]
     x = tf.broadcast_to(x, x_shape)
     indices = tf.broadcast_to(indices, indices_shape)
 
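As a usage sketch of the now-compilable path: the example below assumes the TensorFlow backend is active, and the shapes are made up, since per the commit message no simple reproduction of the original failure was distilled.

    import tensorflow as tf
    import keras  # assumes KERAS_BACKEND="tensorflow"

    @tf.function(
        jit_compile=True,
        input_signature=[
            tf.TensorSpec([None, 10, 4], tf.float32),  # dynamic batch size
            tf.TensorSpec([None, 10, 1], tf.int64),
        ],
    )
    def top_choice(scores, indices):
        # Traced once with an unknown batch dimension; with this commit,
        # take_along_axis no longer routes through tf.broadcast_dynamic_shape.
        return keras.ops.take_along_axis(scores, indices, axis=-1)

    out = top_choice(
        tf.random.normal([8, 10, 4]),
        tf.zeros([8, 10, 1], dtype=tf.int64),
    )
    print(out.shape)  # (8, 10, 1)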