Skip to content

Implement cbrt function in keras.ops #21453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import cbrt as cbrt
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
from keras.src.ops.numpy import concatenate as concatenate
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import cbrt as cbrt
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
from keras.src.ops.numpy import concatenate as concatenate
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import cbrt as cbrt
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
from keras.src.ops.numpy import concatenate as concatenate
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import cbrt as cbrt
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
from keras.src.ops.numpy import concatenate as concatenate
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,11 @@ def broadcast_to(x, shape):
return jnp.broadcast_to(x, shape)


def cbrt(x):
x = convert_to_tensor(x)
return jnp.cbrt(x)


@sparse.elementwise_unary(linear=False)
def ceil(x):
x = convert_to_tensor(x)
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,18 @@ def broadcast_to(x, shape):
return np.broadcast_to(x, shape)


def cbrt(x):
x = convert_to_tensor(x)

dtype = standardize_dtype(x.dtype)
if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = config.floatx()
elif dtype == "int64":
dtype = "float64"

return np.cbrt(x).astype(dtype)


def ceil(x):
x = convert_to_tensor(x)
if standardize_dtype(x.dtype) == "int64":
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ NumpyDtypeTest::test_hamming
NumpyDtypeTest::test_hanning
NumpyDtypeTest::test_kaiser
NumpyDtypeTest::test_bitwise
NumpyDtypeTest::test_cbrt
NumpyDtypeTest::test_ceil
NumpyDtypeTest::test_concatenate
NumpyDtypeTest::test_corrcoef
Expand Down Expand Up @@ -81,6 +82,7 @@ NumpyOneInputOpsCorrectnessTest::test_hamming
NumpyOneInputOpsCorrectnessTest::test_hanning
NumpyOneInputOpsCorrectnessTest::test_kaiser
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
NumpyOneInputOpsCorrectnessTest::test_cbrt
NumpyOneInputOpsCorrectnessTest::test_conj
NumpyOneInputOpsCorrectnessTest::test_corrcoef
NumpyOneInputOpsCorrectnessTest::test_correlate
Expand Down Expand Up @@ -153,12 +155,14 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
NumpyOneInputOpsDynamicShapeTest::test_angle
NumpyOneInputOpsDynamicShapeTest::test_bartlett
NumpyOneInputOpsDynamicShapeTest::test_blackman
NumpyOneInputOpsDynamicShapeTest::test_cbrt
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
NumpyOneInputOpsDynamicShapeTest::test_deg2rad
NumpyOneInputOpsDynamicShapeTest::test_hamming
NumpyOneInputOpsDynamicShapeTest::test_hanning
NumpyOneInputOpsDynamicShapeTest::test_kaiser
NumpyOneInputOpsStaticShapeTest::test_angle
NumpyOneInputOpsStaticShapeTest::test_cbrt
NumpyOneInputOpsStaticShapeTest::test_deg2rad
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
CoreOpsBehaviorTests::test_scan_invalid_arguments
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,10 @@ def broadcast_to(x, shape):
return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0))


def cbrt(x):
raise NotImplementedError("`cbrt` is not supported with openvino backend")


def ceil(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.ceil(x).output(0))
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,18 @@ def broadcast_to(x, shape):
return tf.broadcast_to(x, shape)


def cbrt(x):
x = convert_to_tensor(x)

dtype = standardize_dtype(x.dtype)
if dtype == "int64":
x = tf.cast(x, "float64")
elif dtype not in ["bfloat16", "float16", "float64"]:
x = tf.cast(x, config.floatx())

return tf.sign(x) * tf.pow(tf.abs(x), 1.0 / 3.0)


@sparse.elementwise_unary
def ceil(x):
x = convert_to_tensor(x)
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,18 @@ def broadcast_to(x, shape):
return torch.broadcast_to(x, shape)


def cbrt(x):
x = convert_to_tensor(x)

dtype = standardize_dtype(x.dtype)
if dtype == "bool":
x = cast(x, "int32")
elif dtype == "int64":
x = cast(x, "float64")

return torch.sign(x) * torch.abs(x) ** (1.0 / 3.0)


def ceil(x):
x = convert_to_tensor(x)
ori_dtype = standardize_dtype(x.dtype)
Expand Down
23 changes: 23 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,29 @@ def broadcast_to(x, shape):
return backend.numpy.broadcast_to(x, shape)


class Cbrt(Operation):
def call(self, x):
return backend.numpy.cbrt(x)


@keras_export(["keras.ops.cbrt", "keras.ops.numpy.cbrt"])
def cbrt(x):
"""Computes the cube root of the input tensor, element-wise.

This operation returns the real-valued cube root of `x`, handling
negative numbers properly in the real domain.

Args:
x: Input tensor.

Returns:
A tensor containing the cube root of each element in `x`.
"""
if any_symbolic_tensors((x,)):
return Cbrt().symbolic_call(x)
return backend.numpy.cbrt(x)


class Ceil(Operation):
def call(self, x):
return backend.numpy.ceil(x)
Expand Down
33 changes: 33 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,10 @@ def test_broadcast_to(self):
x = KerasTensor((3, 3))
knp.broadcast_to(x, (2, 2, 3))

def test_cbrt(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.cbrt(x).shape, (None, 3))

def test_ceil(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.ceil(x).shape, (None, 3))
Expand Down Expand Up @@ -1875,6 +1879,10 @@ def test_broadcast_to(self):
x = KerasTensor((3, 3))
knp.broadcast_to(x, (2, 2, 3))

def test_cbrt(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.cbrt(x).shape, (2, 3))

def test_ceil(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.ceil(x).shape, (2, 3))
Expand Down Expand Up @@ -3737,6 +3745,17 @@ def test_broadcast_to(self):
np.broadcast_to(x, [2, 2, 3]),
)

def test_cbrt(self):
x = np.array([[-8, -1, 0], [1, 8, 27]], dtype="float32")
ref_y = np.sign(x) * np.abs(x) ** (1.0 / 3.0)
y = knp.cbrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

y = knp.Cbrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

def test_ceil(self):
x = np.array([[1.2, 2.1, -2.5], [2.4, -11.9, -5.5]])
self.assertAllClose(knp.ceil(x), np.ceil(x))
Expand Down Expand Up @@ -6479,6 +6498,20 @@ def test_broadcast_to(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_cbrt(self, dtype):
import jax.numpy as jnp

x1 = knp.ones((1,), dtype=dtype)
x1_jax = jnp.ones((1,), dtype=dtype)
expected_dtype = standardize_dtype(jnp.cbrt(x1_jax).dtype)

self.assertEqual(standardize_dtype(knp.cbrt(x1).dtype), expected_dtype)
self.assertEqual(
standardize_dtype(knp.Cbrt().symbolic_call(x1).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_ceil(self, dtype):
import jax.numpy as jnp
Expand Down