diff --git a/keras/src/testing/test_ops_int_cast.py b/keras/src/testing/test_ops_int_cast.py new file mode 100644 index 000000000000..e402e02adc1d --- /dev/null +++ b/keras/src/testing/test_ops_int_cast.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +import keras +from keras import backend +from keras import layers +from keras import ops +from keras.src.utils.arg_casts import _maybe_convert_to_int + + +@pytest.mark.skipif( + backend.backend() in ["numpy", "openvino"], + reason="fit() not implemented for NumPy/OpenVINO backends", +) +def test_dense_accepts_ops_prod_units_and_call_ops_prod(): + class ProdDenseLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_shape): + units = ops.prod(input_shape[1:]) + self.dense = layers.Dense(_maybe_convert_to_int(units)) + self.dense.build(input_shape) + + def call(self, inputs): + scale_factor = ops.prod(ops.shape(inputs)[1:]) + scaled_inputs = inputs * ops.cast(scale_factor, inputs.dtype) + return self.dense(scaled_inputs) + + batch_size = 4 + input_shape = (10,) + X_train = np.random.randn(batch_size * 2, *input_shape).astype(np.float32) + y_train = np.random.randint(0, 2, (batch_size * 2, 10)).astype(np.float32) + + inp = keras.Input(shape=input_shape) + out = ProdDenseLayer()(inp) + model = keras.Model(inputs=inp, outputs=out) + + model.compile(optimizer="adam", loss="binary_crossentropy") + model.fit(X_train, y_train, epochs=1, batch_size=batch_size, verbose=0) diff --git a/keras/src/utils/arg_casts.py b/keras/src/utils/arg_casts.py new file mode 100644 index 000000000000..01ae9a70e54b --- /dev/null +++ b/keras/src/utils/arg_casts.py @@ -0,0 +1,36 @@ +from typing import Any + +import numpy as np + +from keras import ops + + +def _maybe_convert_to_int(x: Any) -> Any: + if isinstance(x, int): + return x + if isinstance(x, (tuple, list)): + try: + return tuple(int(v) for v in x) + except Exception: + return x + + try: + np_val = ops.convert_to_numpy(x) + except Exception: + return x + + if np.isscalar(np_val): + try: + return int(np_val) + except Exception: + return x + + arr = np.asarray(np_val).ravel() + if arr.size == 0: + return x + if arr.size == 1: + return int(arr[0]) + try: + return tuple(int(v) for v in arr.tolist()) + except Exception: + return x