Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit b39f31f

Browse files
author
McCabe, Robert J
committed
Test for inter_batch_stride TimeseriesGenerator option
1 parent feb49b7 commit b39f31f

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

keras_preprocessing/sequence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,8 @@ def __init__(self, data, targets, length,
393393
end_index=None,
394394
shuffle=False,
395395
reverse=False,
396-
inter_batch_stride=None,
397-
batch_size=128):
396+
batch_size=128,
397+
inter_batch_stride=None):
398398

399399
if len(data) != len(targets):
400400
raise ValueError('Data and targets have to be' +
@@ -478,12 +478,12 @@ def get_config(self):
478478
'length': self.length,
479479
'sampling_rate': self.sampling_rate,
480480
'stride': self.stride,
481-
'inter_batch_stride': self.inter_batch_stride,
482481
'start_index': self.start_index,
483482
'end_index': self.end_index,
484483
'shuffle': self.shuffle,
485484
'reverse': self.reverse,
486-
'batch_size': self.batch_size
485+
'batch_size': self.batch_size,
486+
'inter_batch_stride': self.inter_batch_stride
487487
}
488488

489489
def to_json(self, **kwargs):

tests/sequence_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,28 @@ def test_TimeseriesGenerator():
213213
assert (np.allclose(data_gen[0][1],
214214
np.array([[20], [21]])))
215215

216+
data_gen = sequence.TimeseriesGenerator(data, targets,
217+
length=5, sampling_rate=1,
218+
batch_size=4, stride=2,
219+
inter_batch_stride=5)
220+
assert len(data_gen) == 9
221+
222+
assert np.allclose(data_gen[0][0],
223+
np.array([[[0], [1], [2], [3], [4]],
224+
[[2], [3], [4], [5], [6]],
225+
[[4], [5], [6], [7], [8]],
226+
[[6], [7], [8], [9], [10]]]))
227+
assert np.allclose(data_gen[0][1],
228+
np.array([[5], [7], [9], [11]]))
229+
230+
assert np.allclose(data_gen[1][0],
231+
np.array([[[5], [6], [7], [8], [9]],
232+
[[7], [8], [9], [10], [11]],
233+
[[9], [10], [11], [12], [13]],
234+
[[11], [12], [13], [14], [15]]]))
235+
assert np.allclose(data_gen[1][1],
236+
np.array([[10], [12], [14], [16]]))
237+
216238
data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)])
217239
targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)])
218240
data_gen = sequence.TimeseriesGenerator(data, targets,

0 commit comments

Comments
 (0)