Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit feb49b7

Browse files
author
McCabe, Robert J
committed
Add inter_batch_stride param to TimeSeriesGenerator
The original TimeSeriesGenerator emits batches where the first sample is a stride ahead of the previous batch's last sample. Many times more control is required (especially if feeding to a stateful RNN). The inter_batch_stride option allows user to explicitly specify the inter-batch stride relationships. If this option is None (or not supplied), the original behavior is maintained.
1 parent c1a7c7f commit feb49b7

File tree

1 file changed

+81
-5
lines changed

1 file changed

+81
-5
lines changed

keras_preprocessing/sequence.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,16 @@ class TimeseriesGenerator(object):
278278
`data[i]`, `data[i-r]`, ... `data[i - length]`
279279
are used for create a sample sequence.
280280
stride: Period between successive output sequences.
281-
For stride `s`, consecutive output samples would
282-
be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
281+
For stride `Stride`, consecutive output samples would
282+
be centered around `data[i]`, `data[i+Stride]`, `data[i+2*Stride]`, etc.
283+
If ``inter_batch_stride`` is None, then `Stride` carries over
284+
between batches: the first sample of a batch is a stride ahead of the
285+
last sample of the previous batch. The formula for the collection of
286+
samples in a batch when ``inter_batch_stride = None`` is
287+
```[[[data[batch_size*Stride*b + s*Stride + l]
288+
for l in range(length)] for s in range(batch_size)]
289+
for b in range(num_batches)]
290+
```
283291
start_index: Data points earlier than `start_index` will not be used
284292
in the output sequences. This is useful to reserve part of the
285293
data for test or validation.
@@ -292,6 +300,10 @@ class TimeseriesGenerator(object):
292300
in reverse chronological order.
293301
batch_size: Number of timeseries samples in each batch
294302
(except maybe the last one).
303+
inter_batch_stride: If not None, grants explicit control w.r.t. the
304+
inter-batch stride relationship -- instead of the first sample
305+
of a batch being a stride ahead of the last sample in the previous
306+
batch.
295307
296308
# Returns
297309
A [Sequence](/utils/#sequence) instance.
@@ -305,6 +317,8 @@ class TimeseriesGenerator(object):
305317
data = np.array([[i] for i in range(50)])
306318
targets = np.array([[i] for i in range(50)])
307319
320+
####
321+
#Test 1
308322
data_gen = TimeseriesGenerator(data, targets,
309323
length=10, sampling_rate=2,
310324
batch_size=2)
@@ -317,6 +331,58 @@ class TimeseriesGenerator(object):
317331
[[1], [3], [5], [7], [9]]]))
318332
assert np.array_equal(y,
319333
np.array([[10], [11]]))
334+
335+
####
336+
#Test 2
337+
data_gen = TimeseriesGenerator(data, targets,
338+
length=5, sampling_rate=1,
339+
batch_size=4, stride=2,
340+
inter_batch_stride=5)
341+
assert len(data_gen) == 9
342+
343+
#First batch
344+
assert np.array_equal(data_gen[0][0],
345+
np.array([[ [0], [1], [2], [3], [4]],
346+
[ [2], [3], [4], [5], [6]],
347+
[ [4], [5], [6], [7], [8]],
348+
[ [6], [7], [8], [9], [10]]]))
349+
assert np.array_equal(data_gen[0][1],
350+
np.array([[5], [7], [9], [11]]))
351+
352+
#Second batch
353+
assert np.array_equal(data_gen[1][0],
354+
np.array([[ [5], [6], [7], [8], [9]],
355+
[ [7], [8], [9], [10], [11]],
356+
[ [9], [10], [11], [12], [13]],
357+
[[11], [12], [13], [14], [15]]]))
358+
assert np.array_equal(data_gen[1][1],
359+
np.array([[10], [12], [14], [16]]))
360+
361+
####
362+
#Test 3
363+
data_gen = TimeseriesGenerator(data, targets,
364+
length=5, sampling_rate=1,
365+
batch_size=4, stride=2,
366+
inter_batch_stride=None)
367+
assert len(data_gen) == 6
368+
369+
#First batch
370+
assert np.array_equal(data_gen[0][0],
371+
np.array([[ [0], [1], [2], [3], [4]],
372+
[ [2], [3], [4], [5], [6]],
373+
[ [4], [5], [6], [7], [8]],
374+
[ [6], [7], [8], [9], [10]]]))
375+
assert np.array_equal(data_gen[0][1],
376+
np.array([[5], [7], [9], [11]]))
377+
378+
#Second batch
379+
assert np.array_equal(data_gen[1][0],
380+
np.array([[ [8], [9], [10], [11], [12]],
381+
[[10], [11], [12], [13], [14]],
382+
[[12], [13], [14], [15], [16]],
383+
[[14], [15], [16], [17], [18]]]))
384+
assert np.array_equal(data_gen[1][1],
385+
np.array([[13], [15], [17], [19]]))
320386
```
321387
"""
322388

@@ -327,6 +393,7 @@ def __init__(self, data, targets, length,
327393
end_index=None,
328394
shuffle=False,
329395
reverse=False,
396+
inter_batch_stride=None,
330397
batch_size=128):
331398

332399
if len(data) != len(targets):
@@ -340,6 +407,7 @@ def __init__(self, data, targets, length,
340407
self.length = length
341408
self.sampling_rate = sampling_rate
342409
self.stride = stride
410+
self.inter_batch_stride = inter_batch_stride
343411
self.start_index = start_index + length
344412
if end_index is None:
345413
end_index = len(data) - 1
@@ -355,15 +423,22 @@ def __init__(self, data, targets, length,
355423
% (self.start_index, self.end_index))
356424

357425
def __len__(self):
358-
return (self.end_index - self.start_index +
359-
self.batch_size * self.stride) // (self.batch_size * self.stride)
426+
if self.inter_batch_stride:
427+
return ((self.end_index - self.start_index +
428+
self.inter_batch_stride) // self.inter_batch_stride)
429+
else:
430+
return (self.end_index - self.start_index +
431+
self.batch_size * self.stride) // (self.batch_size * self.stride)
360432

361433
def __getitem__(self, index):
362434
if self.shuffle:
363435
rows = np.random.randint(
364436
self.start_index, self.end_index + 1, size=self.batch_size)
365437
else:
366-
i = self.start_index + self.batch_size * self.stride * index
438+
if self.inter_batch_stride:
439+
i = self.start_index + self.inter_batch_stride*index
440+
else:
441+
i = self.start_index + self.batch_size * self.stride * index
367442
rows = np.arange(i, min(i + self.batch_size *
368443
self.stride, self.end_index + 1), self.stride)
369444

@@ -403,6 +478,7 @@ def get_config(self):
403478
'length': self.length,
404479
'sampling_rate': self.sampling_rate,
405480
'stride': self.stride,
481+
'inter_batch_stride': self.inter_batch_stride,
406482
'start_index': self.start_index,
407483
'end_index': self.end_index,
408484
'shuffle': self.shuffle,

0 commit comments

Comments
 (0)