diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 4c46eac1e524..7cb93228eacd 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -149,6 +149,7 @@ from keras.src.ops.numpy import bitwise_xor as bitwise_xor from keras.src.ops.numpy import blackman as blackman from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt from keras.src.ops.numpy import ceil as ceil from keras.src.ops.numpy import clip as clip from keras.src.ops.numpy import concatenate as concatenate diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 49c3e9eca6c5..78f51f759988 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -38,6 +38,7 @@ from keras.src.ops.numpy import bitwise_xor as bitwise_xor from keras.src.ops.numpy import blackman as blackman from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt from keras.src.ops.numpy import ceil as ceil from keras.src.ops.numpy import clip as clip from keras.src.ops.numpy import concatenate as concatenate diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 4c46eac1e524..7cb93228eacd 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -149,6 +149,7 @@ from keras.src.ops.numpy import bitwise_xor as bitwise_xor from keras.src.ops.numpy import blackman as blackman from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt from keras.src.ops.numpy import ceil as ceil from keras.src.ops.numpy import clip as clip from keras.src.ops.numpy import concatenate as concatenate diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 49c3e9eca6c5..78f51f759988 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -38,6 +38,7 @@ from keras.src.ops.numpy import bitwise_xor as bitwise_xor from keras.src.ops.numpy import blackman as blackman from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt from keras.src.ops.numpy import ceil as ceil from keras.src.ops.numpy import clip as clip from keras.src.ops.numpy import concatenate as concatenate diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 3a7c5b61bb88..530b0a614769 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -499,6 +499,11 @@ def broadcast_to(x, shape): return jnp.broadcast_to(x, shape) +def cbrt(x): + x = convert_to_tensor(x) + return jnp.cbrt(x) + + @sparse.elementwise_unary(linear=False) def ceil(x): x = convert_to_tensor(x) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index 6f33a4bf3c86..c5e701ab2c6c 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -414,6 +414,18 @@ def broadcast_to(x, shape): return np.broadcast_to(x, shape) +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype == "int64": + dtype = "float64" + + return np.cbrt(x).astype(dtype) + + def ceil(x): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index f370625284c8..61209de61060 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -14,6 +14,7 @@ NumpyDtypeTest::test_hamming NumpyDtypeTest::test_hanning NumpyDtypeTest::test_kaiser NumpyDtypeTest::test_bitwise +NumpyDtypeTest::test_cbrt NumpyDtypeTest::test_ceil NumpyDtypeTest::test_concatenate NumpyDtypeTest::test_corrcoef @@ -81,6 +82,7 @@ NumpyOneInputOpsCorrectnessTest::test_hamming NumpyOneInputOpsCorrectnessTest::test_hanning NumpyOneInputOpsCorrectnessTest::test_kaiser NumpyOneInputOpsCorrectnessTest::test_bitwise_invert +NumpyOneInputOpsCorrectnessTest::test_cbrt NumpyOneInputOpsCorrectnessTest::test_conj NumpyOneInputOpsCorrectnessTest::test_corrcoef NumpyOneInputOpsCorrectnessTest::test_correlate @@ -153,12 +155,14 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot NumpyOneInputOpsDynamicShapeTest::test_angle NumpyOneInputOpsDynamicShapeTest::test_bartlett NumpyOneInputOpsDynamicShapeTest::test_blackman +NumpyOneInputOpsDynamicShapeTest::test_cbrt NumpyOneInputOpsDynamicShapeTest::test_corrcoef NumpyOneInputOpsDynamicShapeTest::test_deg2rad NumpyOneInputOpsDynamicShapeTest::test_hamming NumpyOneInputOpsDynamicShapeTest::test_hanning NumpyOneInputOpsDynamicShapeTest::test_kaiser NumpyOneInputOpsStaticShapeTest::test_angle +NumpyOneInputOpsStaticShapeTest::test_cbrt NumpyOneInputOpsStaticShapeTest::test_deg2rad CoreOpsBehaviorTests::test_associative_scan_invalid_arguments CoreOpsBehaviorTests::test_scan_invalid_arguments diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 4f0a3fedd12e..a5b60451a2bc 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -545,6 +545,10 @@ def broadcast_to(x, shape): return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0)) +def cbrt(x): + raise NotImplementedError("`cbrt` is not supported with openvino backend") + + def ceil(x): x = get_ov_output(x) return OpenVINOKerasTensor(ov_opset.ceil(x).output(0)) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 1bee43030565..b1f9ef211618 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1089,6 +1089,18 @@ def broadcast_to(x, shape): return tf.broadcast_to(x, shape) +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype == "int64": + x = tf.cast(x, "float64") + elif dtype not in ["bfloat16", "float16", "float64"]: + x = tf.cast(x, config.floatx()) + + return tf.sign(x) * tf.pow(tf.abs(x), 1.0 / 3.0) + + @sparse.elementwise_unary def ceil(x): x = convert_to_tensor(x) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index a7699ac6746e..b6e499754cc5 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -540,6 +540,18 @@ def broadcast_to(x, shape): return torch.broadcast_to(x, shape) +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype == "bool": + x = cast(x, "int32") + elif dtype == "int64": + x = cast(x, "float64") + + return torch.sign(x) * torch.abs(x) ** (1.0 / 3.0) + + def ceil(x): x = convert_to_tensor(x) ori_dtype = standardize_dtype(x.dtype) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index bb9525be2203..b889bf25fe89 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -1779,6 +1779,29 @@ def broadcast_to(x, shape): return backend.numpy.broadcast_to(x, shape) +class Cbrt(Operation): + def call(self, x): + return backend.numpy.cbrt(x) + + +@keras_export(["keras.ops.cbrt", "keras.ops.numpy.cbrt"]) +def cbrt(x): + """Computes the cube root of the input tensor, element-wise. + + This operation returns the real-valued cube root of `x`, handling + negative numbers properly in the real domain. + + Args: + x: Input tensor. + + Returns: + A tensor containing the cube root of each element in `x`. + """ + if any_symbolic_tensors((x,)): + return Cbrt().symbolic_call(x) + return backend.numpy.cbrt(x) + + class Ceil(Operation): def call(self, x): return backend.numpy.ceil(x) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 14b30d82ced2..95a8ae1d4782 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1255,6 +1255,10 @@ def test_broadcast_to(self): x = KerasTensor((3, 3)) knp.broadcast_to(x, (2, 2, 3)) + def test_cbrt(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cbrt(x).shape, (None, 3)) + def test_ceil(self): x = KerasTensor((None, 3)) self.assertEqual(knp.ceil(x).shape, (None, 3)) @@ -1875,6 +1879,10 @@ def test_broadcast_to(self): x = KerasTensor((3, 3)) knp.broadcast_to(x, (2, 2, 3)) + def test_cbrt(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cbrt(x).shape, (2, 3)) + def test_ceil(self): x = KerasTensor((2, 3)) self.assertEqual(knp.ceil(x).shape, (2, 3)) @@ -3737,6 +3745,17 @@ def test_broadcast_to(self): np.broadcast_to(x, [2, 2, 3]), ) + def test_cbrt(self): + x = np.array([[-8, -1, 0], [1, 8, 27]], dtype="float32") + ref_y = np.sign(x) * np.abs(x) ** (1.0 / 3.0) + y = knp.cbrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + y = knp.Cbrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + def test_ceil(self): x = np.array([[1.2, 2.1, -2.5], [2.4, -11.9, -5.5]]) self.assertAllClose(knp.ceil(x), np.ceil(x)) @@ -6479,6 +6498,20 @@ def test_broadcast_to(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cbrt(self, dtype): + import jax.numpy as jnp + + x1 = knp.ones((1,), dtype=dtype) + x1_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cbrt(x1_jax).dtype) + + self.assertEqual(standardize_dtype(knp.cbrt(x1).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Cbrt().symbolic_call(x1).dtype), + expected_dtype, + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_ceil(self, dtype): import jax.numpy as jnp