@@ -2260,9 +2260,7 @@ def fix_negative_indices(i):
 
 
 def take_along_axis(x, indices, axis=None):
-    from keras.src.ops.operation_utils import (
-        compute_take_along_axis_output_shape,
-    )
+    from keras.src.ops import operation_utils
 
     x = convert_to_tensor(x)
     indices = convert_to_tensor(indices, "int64")
@@ -2276,28 +2274,46 @@ def take_along_axis(x, indices, axis=None):
 
     # Compute the static output shape, as later on all shape manipulations
     # use dynamic shapes.
-    static_output_shape = compute_take_along_axis_output_shape(
+    static_output_shape = operation_utils.compute_take_along_axis_output_shape(
         x.shape, indices.shape, axis
     )
     rank = x.ndim
     static_axis = axis
     axis = axis + rank if axis < 0 else axis
 
-    # Broadcast shapes to match, ensure that the axis of interest is not
-    # broadcast.
-    x_shape_original = tf.shape(x, out_type=indices.dtype)
-    indices_shape_original = tf.shape(indices, out_type=indices.dtype)
-    x_shape = tf.tensor_scatter_nd_update(x_shape_original, [[axis]], [1])
-    indices_shape = tf.tensor_scatter_nd_update(
-        indices_shape_original, [[axis]], [1]
-    )
-    broadcasted_shape = tf.broadcast_dynamic_shape(x_shape, indices_shape)
-    x_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [x_shape_original[axis]]
-    )
-    indices_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [indices_shape_original[axis]]
+    if axis >= rank:
+        raise ValueError(f"Invalid axis: {static_axis} for input rank: {rank}")
+
+    x_original_shape = shape_op(x)
+    indices_original_shape = shape_op(indices)
+
+    # Broadcast the static shapes first, but not for the `axis` dimension.
+    x_static_shape = list(x.shape)
+    indices_static_shape = list(indices.shape)
+    x_static_shape[axis] = 1
+    indices_static_shape[axis] = 1
+    broadcast_shape = operation_utils.broadcast_shapes(
+        x_static_shape, indices_static_shape
     )
+
+    if None in broadcast_shape:
+        # Dynamic broadcast case. Note that `tf.broadcast_dynamic_shape` is
+        # not always XLA compilable with dynamic dimensions.
+        # We replace `None`s with the dynamic dimensions.
+        # `maximum` is the correct formula only when shapes are broadcastable;
+        # we rely on the broadcast itself to fail in the incorrect case rather
+        # than make some expensive dynamic checks here.
+        broadcast_shape = [
+            tf.maximum(x_original_shape[i], indices_original_shape[i])
+            if dim is None
+            else dim
+            for i, dim in enumerate(broadcast_shape)
+        ]
+
+    x_shape = list(broadcast_shape)
+    x_shape[axis] = x_original_shape[axis]
+    indices_shape = list(broadcast_shape)
+    indices_shape[axis] = indices_original_shape[axis]
     x = tf.broadcast_to(x, x_shape)
     indices = tf.broadcast_to(indices, indices_shape)
 
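The key idea of the patch: broadcast `x` and `indices` against each other on every dimension except `axis`, and, where static shapes are unknown, derive the common dynamic dimensions with `tf.maximum` instead of `tf.broadcast_dynamic_shape`, which does not always compile under XLA. Below is a minimal standalone sketch of that scheme; the helper name `broadcast_all_but_axis` and the toy shapes are illustrative, not part of the patch, and the sketch applies the dynamic `maximum` rule to every dimension for brevity, whereas the patch only falls back to it for statically unknown dimensions.

```python
import tensorflow as tf


def broadcast_all_but_axis(x, indices, axis):
    # Hypothetical helper (not from the patch): broadcast `x` and `indices`
    # to a common shape on every dimension except `axis`, using only
    # XLA-friendly ops.
    rank = x.shape.rank
    x_dyn = tf.shape(x, out_type=tf.int64)
    ind_dyn = tf.shape(indices, out_type=tf.int64)
    # `maximum` is only correct for broadcastable dims (equal, or one of
    # them is 1); like the patch, we let `tf.broadcast_to` fail otherwise.
    common = [tf.maximum(x_dyn[i], ind_dyn[i]) for i in range(rank)]
    x_shape, ind_shape = list(common), list(common)
    x_shape[axis] = x_dyn[axis]  # keep the gather axis un-broadcast
    ind_shape[axis] = ind_dyn[axis]
    return (
        tf.broadcast_to(x, tf.stack(x_shape)),
        tf.broadcast_to(indices, tf.stack(ind_shape)),
    )


x = tf.random.uniform((2, 1, 5))
idx = tf.zeros((1, 3, 2), dtype=tf.int64)
xb, ib = broadcast_all_but_axis(x, idx, axis=2)
print(xb.shape, ib.shape)  # (2, 3, 5) (2, 3, 2)
```

Note the design choice this mirrors: every dimension except `axis` ends up identical between the two tensors, while the `axis` dimension keeps its original size on each side, which is exactly the precondition a subsequent element-wise gather along `axis` needs.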