Proposed fix for issue #21519: Reshape layer does not handle -1 shape… #21568

Open · wants to merge 8 commits into master
44 changes: 32 additions & 12 deletions keras/src/layers/reshaping/reshape.py
@@ -1,3 +1,5 @@
import math

from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import KerasTensor
@@ -37,7 +39,21 @@ class Reshape(Layer):

    def __init__(self, target_shape, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = tuple(target_shape)
        target_shape = tuple(target_shape)
        # test validity of target_shape
        if target_shape.count(-1) > 1:
            raise ValueError(
                "The `target_shape` argument must not contain more than one "
                "`-1` value. Received: target_shape={}".format(target_shape)
Collaborator: Please use an f-string for this.
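A minimal sketch of the f-string form being asked for, keeping the PR's wording (the standalone `target_shape` value below is a hypothetical example used only to trigger the check):

    import math  # not needed for the message itself, shown for parity with the diff

    # Sketch only: f-string version of the validation error.
    target_shape = (2, -1, -1)  # hypothetical invalid value
    if target_shape.count(-1) > 1:
        raise ValueError(
            "The `target_shape` argument must not contain more than one "
            f"`-1` value. Received: target_shape={target_shape}"
        )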

            )
        self.target_shape = target_shape
        # precalculate all values that might be required
        self.need_explicit_shape_for_batch_size_None = (
            target_shape.count(-1) == 1
        )
        self.new_size_no_minus_one = math.prod(
            d for d in target_shape if d != -1
        )
Collaborator: Because you removed the build method, add `self.built = True` at the end of `__init__`.
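A minimal sketch of where that would go, assuming `__init__` otherwise stays as in this diff (the class name `ReshapeSketch` and the public `keras.layers.Layer` base are used here only for illustration):

    import math

    from keras import layers


    class ReshapeSketch(layers.Layer):
        # Sketch only: shows the end of __init__ with the reviewer's suggestion applied.
        def __init__(self, target_shape, **kwargs):
            super().__init__(**kwargs)
            self.target_shape = tuple(target_shape)
            self.new_size_no_minus_one = math.prod(
                d for d in self.target_shape if d != -1
            )
            # build() was removed in this PR, so mark the layer as built here.
            self.built = True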


    def compute_output_shape(self, input_shape):
        return (
@@ -53,18 +69,22 @@ def compute_output_spec(self, inputs):
            shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
        )

    def build(self, input_shape):
        sample_output_shape = operation_utils.compute_reshape_output_shape(
            input_shape[1:], self.target_shape, "target_shape"
        )
        self._resolved_target_shape = tuple(
            -1 if d is None else d for d in sample_output_shape
        )

    def call(self, inputs):
        return ops.reshape(
            inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
        )
        target_shape = self.target_shape
Collaborator: I think it's easier to just transfer the code from `build` here:

    def call(self, inputs):
        potentially_resolved_target_shape = (
            operation_utils.compute_reshape_output_shape(
                tuple(inputs.shape)[1:], self.target_shape, "target_shape"
            )
        )
        potentially_resolved_target_shape = tuple(
            -1 if d is None else d for d in potentially_resolved_target_shape
        )
        return ops.reshape(
            inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape
        )
  • you don't have to reimplement the computation of the missing dimension
  • you don't have to deal with errors if the number of values is not divisible by `self.new_size_no_minus_one` (right now, that check is missing; see the sketch after this list)
  • you don't need self.need_explicit_shape_for_batch_size_None and self.new_size_no_minus_one
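To make the missing-check point concrete, here is a short illustration (a sketch, not part of the PR; the shapes are hypothetical) of how the arithmetic in the PR's `call()` silently floor-divides when the input size is not a multiple of the known target dimensions:

    import math

    # Sketch only: mirrors the PR's call() arithmetic for target_shape=(-1, 8).
    target_shape = (-1, 8)
    new_size_no_minus_one = math.prod(d for d in target_shape if d != -1)  # 8

    # Hypothetical input of shape (1, 7, 4): 28 non-batch elements, not divisible by 8.
    input_nonbatch_shape = (7, 4)
    inp_nonbatch_size = math.prod(input_nonbatch_shape)  # 28
    resolved = tuple(
        d if d != -1 else inp_nonbatch_size // new_size_no_minus_one
        for d in target_shape
    )
    print(resolved)  # (3, 8) -> 24 elements, so the backend reshape fails with a
                     # low-level error instead of a clear shape-mismatch message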

        if self.need_explicit_shape_for_batch_size_None and (
            inputs.shape[0] is None
        ):
            input_nonbatch_shape = tuple(inputs.shape[1:])
            if input_nonbatch_shape.count(None) == 0:
                inp_nonbatch_size = math.prod(inputs.shape[1:])
                target_shape = tuple(
                    d
                    if d != -1
                    else (inp_nonbatch_size // self.new_size_no_minus_one)
                    for d in self.target_shape
                )

        return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)

    def get_config(self):
        config = {"target_shape": self.target_shape}
28 changes: 28 additions & 0 deletions keras/src/layers/reshaping/reshape_test.py
@@ -1,8 +1,10 @@
import pytest
from absl.testing import parameterized

from keras.src import Model
from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src import testing
from keras.src.backend.common.keras_tensor import KerasTensor

@@ -100,6 +102,32 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
        reshaped = backend.compute_output_spec(layer.__call__, input)
        self.assertEqual(reshaped.shape, (None, 3, 8))

    def test_reshape_layer_with_varying_input_size_and_minus_one(self):
        input = KerasTensor((None, 6, 4))
Collaborator: `input = KerasTensor((None, 6, 4))` is no longer needed in this test, remove.

        layer = layers.Reshape((-1, 8))
        layer.build(input.shape)
Collaborator: You could remove all the calls to `build` in this test and others, it's no longer needed (just to make the code shorter).

        res = layer(ops.ones((1, 6, 4), dtype="float32"))
        self.assertEqual(res.shape, (1, 3, 8))
        res = layer(ops.ones((1, 10, 4), dtype="float32"))
        self.assertEqual(res.shape, (1, 5, 8))

    def test_custom_reshape_model_with_varying_input_size_and_minus_one(self):
Collaborator: Let's do a compiled example instead, which exercises a different use case than the example right before. Also, the model can be just the Reshape layer:

    def test_reshape_model_fit_with_varying_input_size_and_minus_one(self):
        def generator():
            yield (
                ops.ones((1, 12, 2), dtype="float32"),
                ops.zeros((1, 3, 8), dtype="float32"),
            )
            yield (
                ops.ones((1, 20, 2), dtype="float32"),
                ops.zeros((1, 5, 8), dtype="float32"),
            )

        layer = layers.Reshape((-1, 8))
        model = models.Sequential([layer])
        model.compile(loss="mean_squared_error")
        model.fit(generator())

        class MM(Model):
            def __init__(self):
                super().__init__()
                self.conv = layers.Conv1D(4, 3, padding="same")
                self.reshape = layers.Reshape((-1, 8))

            def call(self, inputs):
                x = self.conv(inputs)
                return self.reshape(x)

        m = MM()
        res = m(ops.ones((1, 6, 2), dtype="float32"))
        self.assertEqual(res.shape, (1, 3, 8))
        res = m(ops.ones((1, 10, 2), dtype="float32"))
        self.assertEqual(res.shape, (1, 5, 8))

    def test_reshape_with_dynamic_dim_and_minus_one(self):
        input = KerasTensor((4, 6, None, 3))
        layer = layers.Reshape((-1, 3))