diff --git a/keras_preprocessing/image/iterator.py b/keras_preprocessing/image/iterator.py index c62b1d3a..818fef17 100644 --- a/keras_preprocessing/image/iterator.py +++ b/keras_preprocessing/image/iterator.py @@ -51,6 +51,8 @@ def __getitem__(self, idx): 'but the Sequence ' 'has length {length}'.format(idx=idx, length=len(self))) + if idx < 0: + idx = len(self) + idx if self.seed is not None: np.random.seed(self.seed + self.total_batches_seen) self.total_batches_seen += 1 diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index 74660cef..85adbfa1 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -356,6 +356,8 @@ def __len__(self): self.batch_size * self.stride) // (self.batch_size * self.stride) def __getitem__(self, index): + if index < 0: + index = len(self) + index if self.shuffle: rows = np.random.randint( self.start_index, self.end_index + 1, size=self.batch_size) diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 246ca664..0df51389 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -140,6 +140,26 @@ def test_TimeseriesGenerator_serde(): assert (data_gen.targets == recovered_gen.targets).all() +def test_TimeseriesGenerator_negative_subscript(): + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = sequence.TimeseriesGenerator(data, targets, + length=10, + sampling_rate=2, + batch_size=2) + assert len(data_gen) == 20 + assert (np.allclose(data_gen[19][0], data_gen[-1][0])) + assert (np.allclose(data_gen[19][1], data_gen[-1][1])) + assert (np.allclose(data_gen[18][0], data_gen[-2][0])) + assert (np.allclose(data_gen[18][1], data_gen[-2][1])) + + size = len(data_gen) + for i in range(1, size + 1): + assert (np.allclose(data_gen[size - i][0], data_gen[-i][0])) + assert (np.allclose(data_gen[size - i][1], data_gen[-i][1])) + + def test_TimeseriesGenerator(): data = np.array([[i] for i in range(50)]) targets = np.array([[i] for i in range(50)])