-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Open
Labels
Description
The output of ops.prod doesn't work in layer's argument. To make it work, it's required to cast, i.e. int(ops.prod). Now, still its an issue to use ops.prod in call method anyway.
import keras
from keras import layers, ops
import numpy as np
class ProdDenseLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_shape):
self.input_prod = ops.prod(input_shape[1:])
self.dense = layers.Dense(self.input_prod, activation='relu')
self.dense.build(input_shape)
def call(self, inputs):
scale_factor = ops.prod(ops.shape(inputs)[1:])
scaled_inputs = inputs * ops.cast(scale_factor, inputs.dtype)
return self.dense(scaled_inputs)
# Main model using ops.prod
class ProdModel(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.prod_layer1 = ProdDenseLayer()
def build(self, input_shape):
self.prod_layer1.build(input_shape)
def call(self, inputs):
batch_size = ops.shape(inputs)[0]
total_elements = ops.prod(ops.shape(inputs)[1:])
normalized_inputs = inputs / ops.cast(total_elements, inputs.dtype)
x = self.prod_layer1(normalized_inputs)
return x
# Create dummy data
batch_size = 32
input_shape = (10,)
X_train = np.random.randn(batch_size * 10, *input_shape).astype(np.float32)
y_train = np.random.randint(0, 2, (batch_size * 10, 1)).astype(np.float32)
model = ProdModel()
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy']
)
history = model.fit(
X_train, y_train,
epochs=3,
batch_size=batch_size,
validation_split=0.2,
verbose=1
)
model.summary()Epoch 1/3
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_36/1432196851.py in <cell line: 0>()
45 metrics=['accuracy']
46 )
---> 47 history = model.fit(
48 X_train, y_train,
49 epochs=3,
/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
/tmp/ipykernel_36/1432196851.py in build(self, input_shape)
25
26 def build(self, input_shape):
---> 27 self.prod_layer1.build(input_shape)
28
29 def call(self, inputs):
/tmp/ipykernel_36/1432196851.py in build(self, input_shape)
11 self.input_prod = ops.prod(input_shape[1:])
12 self.dense = layers.Dense(self.input_prod, activation='relu')
---> 13 self.dense.build(input_shape)
14
15 def call(self, inputs):
ValueError: Invalid dtype: <property object at 0x79c67d66d3f0>