Proposed fix for issue #21519: Reshape layer does not handle -1 shape… #21568
keras/src/layers/reshaping/reshape.py

@@ -1,3 +1,5 @@
+import math
+
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common.keras_tensor import KerasTensor

@@ -37,7 +39,21 @@ class Reshape(Layer):
     def __init__(self, target_shape, **kwargs):
         super().__init__(**kwargs)
-        self.target_shape = tuple(target_shape)
+        target_shape = tuple(target_shape)
+        # test validity of target_shape
+        if target_shape.count(-1) > 1:
+            raise ValueError(
+                "The `target_shape` argument must not contain more than one "
+                "`-1` value. Received: target_shape={}".format(target_shape)
+            )
+        self.target_shape = target_shape
+        # precalculate all values that might be required
+        self.need_explicit_shape_for_batch_size_None = (
+            target_shape.count(-1) == 1
+        )
+        self.new_size_no_minus_one = math.prod(
+            d for d in target_shape if d != -1
+        )
Review comment: Because you removed the …
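For illustration, with a target_shape of (-1, 8) (the value used in the tests below), the two attributes precalculated in `__init__` evaluate to:

    import math

    target_shape = (-1, 8)
    # exactly one -1, so an explicit shape must be derived when the batch size is None
    need_explicit_shape_for_batch_size_None = target_shape.count(-1) == 1  # True
    # product of the known (non -1) target dimensions
    new_size_no_minus_one = math.prod(d for d in target_shape if d != -1)  # 8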

     def compute_output_shape(self, input_shape):
         return (
@@ -53,18 +69,22 @@ def compute_output_spec(self, inputs):
             shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
         )

-    def build(self, input_shape):
-        sample_output_shape = operation_utils.compute_reshape_output_shape(
-            input_shape[1:], self.target_shape, "target_shape"
-        )
-        self._resolved_target_shape = tuple(
-            -1 if d is None else d for d in sample_output_shape
-        )
-
     def call(self, inputs):
-        return ops.reshape(
-            inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
-        )
+        target_shape = self.target_shape
Review comment: I think it's easier to just transfer the code from …

    def call(self, inputs):
        potentially_resolved_target_shape = (
            operation_utils.compute_reshape_output_shape(
                tuple(inputs.shape)[1:], self.target_shape, "target_shape"
            )
        )
        potentially_resolved_target_shape = tuple(
            -1 if d is None else d for d in potentially_resolved_target_shape
        )
        return ops.reshape(
            inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape
        )
+        if self.need_explicit_shape_for_batch_size_None and (
+            inputs.shape[0] is None
+        ):
+            input_nonbatch_shape = tuple(inputs.shape[1:])
+            if input_nonbatch_shape.count(None) == 0:
+                inp_nonbatch_size = math.prod(inputs.shape[1:])
+                target_shape = tuple(
+                    d
+                    if d != -1
+                    else (inp_nonbatch_size // self.new_size_no_minus_one)
+                    for d in self.target_shape
+                )
+
+        return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)

     def get_config(self):
         config = {"target_shape": self.target_shape}
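Continuing the illustration above, this is roughly how the new `call` resolves the `-1` entry when the batch size is unknown but the other input dimensions are known (a plain-Python sketch, reusing the example shapes from the tests below):

    import math

    input_shape = (None, 6, 4)  # dynamic batch, known non-batch dims
    target_shape = (-1, 8)
    new_size_no_minus_one = 8   # precalculated in __init__ (see above)

    # The non-batch dims are fully known, so their product resolves the -1.
    inp_nonbatch_size = math.prod(input_shape[1:])  # 6 * 4 = 24
    resolved = tuple(
        d if d != -1 else inp_nonbatch_size // new_size_no_minus_one
        for d in target_shape
    )
    print(resolved)  # (3, 8) -> the layer reshapes to (batch, 3, 8)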
keras/src/layers/reshaping/reshape_test.py

@@ -1,8 +1,10 @@
 import pytest
 from absl.testing import parameterized

+from keras.src import Model
 from keras.src import backend
 from keras.src import layers
+from keras.src import ops
 from keras.src import testing
 from keras.src.backend.common.keras_tensor import KerasTensor
@@ -100,6 +102,32 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
         reshaped = backend.compute_output_spec(layer.__call__, input)
         self.assertEqual(reshaped.shape, (None, 3, 8))

+    def test_reshape_layer_with_varying_input_size_and_minus_one(self):
+        input = KerasTensor((None, 6, 4))
+        layer = layers.Reshape((-1, 8))
+        layer.build(input.shape)
Review comment: You could remove all the calls to …
+        res = layer(ops.ones((1, 6, 4), dtype="float32"))
+        self.assertEqual(res.shape, (1, 3, 8))
+        res = layer(ops.ones((1, 10, 4), dtype="float32"))
+        self.assertEqual(res.shape, (1, 5, 8))
+
+    def test_custom_reshape_model_with_varying_input_size_and_minus_one(self):

Review comment: Let's do a compiled example instead, which exercises a different use case than the example right before. Also, the model can be just the …

    def test_reshape_model_fit_with_varying_input_size_and_minus_one(self):
        def generator():
            yield (
                ops.ones((1, 12, 2), dtype="float32"),
                ops.zeros((1, 3, 8), dtype="float32"),
            )
            yield (
                ops.ones((1, 20, 2), dtype="float32"),
                ops.zeros((1, 5, 8), dtype="float32"),
            )
        layer = layers.Reshape((-1, 8))
        model = models.Sequential([layer])
        model.compile(loss="mean_squared_error")
        model.fit(generator())
+        class MM(Model):
+            def __init__(self):
+                super().__init__()
+                self.conv = layers.Conv1D(4, 3, padding="same")
+                self.reshape = layers.Reshape((-1, 8))
+
+            def call(self, inputs):
+                x = self.conv(inputs)
+                return self.reshape(x)
+
+        m = MM()
+        res = m(ops.ones((1, 6, 2), dtype="float32"))
+        self.assertEqual(res.shape, (1, 3, 8))
+        res = m(ops.ones((1, 10, 2), dtype="float32"))
+        self.assertEqual(res.shape, (1, 5, 8))
+
     def test_reshape_with_dynamic_dim_and_minus_one(self):
         input = KerasTensor((4, 6, None, 3))
         layer = layers.Reshape((-1, 3))
Review comment: Please use an f-string for this.