Inconsistent behaviours between backends on ops.tile #20914

@yshao-aim-solutions

Description

I have the following function that broadcasts two arrays so their multiplication can be computed for all possible combinations. The function uses both tile and repeat, and I found that tile behaves inconsistently between backends.

  • JAX:
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)
  • TensorFlow:
(<KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)

It seems that TensorFlow could not properly infer the shape of the resulting symbolic tensor.
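For context, the tile/repeat combination in the function below is a standard cross-pairing pattern. A minimal NumPy sketch (with made-up shapes, not the issue itself) of what the two ops are expected to produce:

```python
import numpy as np

# Two stacks of 2 "modes" each along axis -3, shape (1, 2, 1, 3)
a = np.arange(6, dtype=np.float32).reshape(1, 2, 1, 3)
b = np.arange(6, dtype=np.float32).reshape(1, 2, 1, 3) + 10.0

# tile repeats the whole block along axis -3: (a0, a1) -> (a0, a1, a0, a1)
a_tiled = np.tile(a, (1, 2, 1, 1))
# repeat duplicates each slice along axis -3: (b0, b1) -> (b0, b0, b1, b1)
b_rep = np.repeat(b, 2, axis=-3)

print(a_tiled.shape, b_rep.shape)  # (1, 4, 1, 3) (1, 4, 1, 3)
```

In NumPy the output shapes are fully determined, which is why the all-None shapes from the TensorFlow backend look like a shape-inference gap rather than expected behaviour.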

Another issue is that when the repeats for tile are computed from the shape of a symbolic tensor, TensorFlow still works (although all output shapes become None), but JAX raises an error: "'str' object has no attribute '_error_repr'". This issue can be reproduced by replacing repeats with the commented-out command.
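For clarity, the commented-out line builds the repeats vector by scatter-updating a vector of ones. Its NumPy equivalent (a sketch assuming rank 4 and a concrete mode dimension of 2, which is what the values would be at runtime) is:

```python
import numpy as np

ndims = 4      # rank of x1real
x2shape = 2    # ops.shape(x2real)[-3], concrete here for illustration

# NumPy equivalent of the scatter_update: all ones, x2shape at axis -3
repeats = np.ones(ndims, dtype=np.int32)
repeats[-3 + ndims] = x2shape
print(repeats.tolist())  # [1, 2, 1, 1]
```

During symbolic (functional) tracing the entries of ops.shape are not all concrete, which appears to be what trips up the JAX backend with the unhelpful '_error_repr' message.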

Reproduction Code

import os
os.environ["KERAS_BACKEND"] = "jax"

from keras import ops, layers
from keras import Input

# %%
def broadcast(x1, x2):
    """
    Broadcast the shapes of x1 and x2 to allow the computation of cross-interaction.
    
    - repeating input1: (a, b) -> (a, b, a, b)
    - repeating input2: (c, d) -> (c, c, d, d)
    - result: (a, b) * (c, d) = (a * c, b * c, a * d, b * d)
    
    Args:
        x1: nD array in shape (..., n1, ny1, nx1) to be broadcast
        x2: nD array in shape (..., n2, ny2, nx2) to be broadcast

    Returns:
        Broadcasted nD arrays in shape (..., n1 * n2, ...)

    Examples:
        >>> import numpy as np
        >>> x1 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x2 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x1, x2 = broadcast((x1, np.zeros(np.shape(x1))), (x2, np.zeros(np.shape(x2))))
       
        >>> np.array(x1[0]) + 1j * np.array(x1[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
        >>> np.array(x2[0]) + 1j * np.array(x2[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
    """

    x1real, x1imag = x1
    x2real, x2imag = x2
    
    x1shape = ops.shape(x1real)[-3] # spatial mode dimension
    x2shape = ops.shape(x2real)[-3] # spatial mode dimension
    
    x1dims = len(ops.shape(x1real))
    # repeats = ops.scatter_update(ops.cast(ops.ones(x1dims), dtype="int32"), [[-3 + x1dims]], [x2shape])
    repeats = [1, 2, 1, 1]
    x1real = ops.tile(x1real, repeats)
    x1imag = ops.tile(x1imag, repeats)

    x2real = ops.repeat(x2real, x1shape, axis=-3)
    x2imag = ops.repeat(x2imag, x1shape, axis=-3)

    return ((x1real, x1imag), (x2real, x2imag))

# %%

class Test(layers.Layer):
    
    def call(self, inputs1, inputs2):

        return broadcast(inputs1, inputs2)

test = Test()

# %%

x1 = Input(shape=(3, 2, 2))
x2 = Input(shape=(2, 2, 2))

y1, y2 = test((x1, x1), (x2, x2))

print(y1)
print(y2)
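A possible workaround I considered (just a sketch, not verified against all backends; make_repeats is a hypothetical helper, not a Keras API) is to build repeats as a plain Python list from the static rank, avoiding symbolic shape values entirely:

```python
def make_repeats(ndims, axis, factor):
    """Build a tile `repeats` list with `factor` at `axis` and 1 elsewhere."""
    repeats = [1] * ndims
    repeats[axis % ndims] = factor
    return repeats

# Inside broadcast(), using the static shape tuple instead of ops.shape:
#   repeats = make_repeats(len(x1real.shape), -3, x2real.shape[-3])
print(make_repeats(4, -3, 2))  # [1, 2, 1, 1]
```

This only works when the mode dimension (axis -3) is statically known, which it is for the Input shapes in the reproduction above.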

Environment
jax 0.5.0
jaxlib 0.5.0
keras 3.8.0
tensorboard 2.18.0
