Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions keras/src/testing/test_ops_int_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np
import pytest

import keras
from keras import backend
from keras import layers
from keras import ops
from keras.src.utils.arg_casts import _maybe_convert_to_int


@pytest.mark.skipif(
backend.backend() in ["numpy", "openvino"],
reason="fit() not implemented for NumPy/OpenVINO backends",
)
def test_dense_accepts_ops_prod_units_and_call_ops_prod():
class ProdDenseLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def build(self, input_shape):
units = ops.prod(input_shape[1:])
self.dense = layers.Dense(_maybe_convert_to_int(units))
self.dense.build(input_shape)

def call(self, inputs):
scale_factor = ops.prod(ops.shape(inputs)[1:])
scaled_inputs = inputs * ops.cast(scale_factor, inputs.dtype)
return self.dense(scaled_inputs)

batch_size = 4
input_shape = (10,)
X_train = np.random.randn(batch_size * 2, *input_shape).astype(np.float32)
y_train = np.random.randint(0, 2, (batch_size * 2, 10)).astype(np.float32)

inp = keras.Input(shape=input_shape)
out = ProdDenseLayer()(inp)
model = keras.Model(inputs=inp, outputs=out)

model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(X_train, y_train, epochs=1, batch_size=batch_size, verbose=0)
36 changes: 36 additions & 0 deletions keras/src/utils/arg_casts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any

import numpy as np

from keras import ops


def _maybe_convert_to_int(x: Any) -> Any:
if isinstance(x, int):
return x
if isinstance(x, (tuple, list)):
try:
return tuple(int(v) for v in x)
except Exception:
return x

try:
np_val = ops.convert_to_numpy(x)
except Exception:
return x

if np.isscalar(np_val):
try:
return int(np_val)
except Exception:
return x

arr = np.asarray(np_val).ravel()
if arr.size == 0:
return x
if arr.size == 1:
return int(arr[0])
try:
return tuple(int(v) for v in arr.tolist())
except Exception:
return x
Comment on lines +8 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function can be made more robust and slightly cleaner. Using except Exception for int() conversions is too broad and can mask unexpected errors. It's better to catch specific exceptions like ValueError and TypeError. Additionally, the conversion int(arr[0]) on line 32 is not wrapped in a try...except block and could raise an unhandled exception if the element is not convertible to an integer.

I've suggested a refactoring that addresses these points by using more specific exceptions and ensuring all integer conversions are safely handled. The broad except Exception for ops.convert_to_numpy is kept, as it's intended to handle various failures from different backends, especially for symbolic tensors.

def _maybe_convert_to_int(x: Any) -> Any:
    if isinstance(x, int):
        return x
    if isinstance(x, (tuple, list)):
        try:
            return tuple(int(v) for v in x)
        except (ValueError, TypeError):
            return x

    try:
        np_val = ops.convert_to_numpy(x)
    except Exception:
        return x

    if np.isscalar(np_val):
        try:
            return int(np_val)
        except (ValueError, TypeError):
            return x

    arr = np.asarray(np_val).ravel()
    if arr.size == 0:
        return x

    try:
        if arr.size == 1:
            return int(arr[0])
        return tuple(int(v) for v in arr.tolist())
    except (ValueError, TypeError):
        return x