@@ -278,8 +278,16 @@ class TimeseriesGenerator(object):
278278 `data[i]`, `data[i-r]`, ... `data[i - length]`
279279 are used to create a sample sequence.
280280 stride: Period between successive output sequences.
281- For stride `s`, consecutive output samples would
282- be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
281+ For stride `Stride`, consecutive output samples would
282+ be centered around `data[i]`, `data[i+Stride]`, `data[i+2*Stride]`, etc.
283+ If `inter_batch_stride` is None, then `Stride` carries over
284+ between batches: the first sample of a batch is a stride ahead of the
285+ last sample of the previous batch. With `start_index=0` and
286+ `sampling_rate=1`, the samples of all batches are then given by:
287+ ```python
288+ [[[data[batch_size*Stride*b + s*Stride + l] for l in range(length)]
289+ for s in range(batch_size)] for b in range(num_batches)]
290+ ```
283291 start_index: Data points earlier than `start_index` will not be used
284292 in the output sequences. This is useful to reserve part of the
285293 data for test or validation.
@@ -292,6 +300,10 @@ class TimeseriesGenerator(object):
292300 in reverse chronological order.
293301 batch_size: Number of timeseries samples in each batch
294302 (except maybe the last one).
303+ inter_batch_stride: Offset, in data points, between the first
304+ samples of consecutive batches. If None (the default), the first
305+ sample of a batch is one `stride` ahead of the last sample of the
306+ previous batch.
295307
296308 # Returns
297309 A [Sequence](/utils/#sequence) instance.
@@ -305,6 +317,8 @@ class TimeseriesGenerator(object):
305317 data = np.array([[i] for i in range(50)])
306318 targets = np.array([[i] for i in range(50)])
307319
320+ ####
321+ #Test 1
308322 data_gen = TimeseriesGenerator(data, targets,
309323 length=10, sampling_rate=2,
310324 batch_size=2)
@@ -317,6 +331,58 @@ class TimeseriesGenerator(object):
317331 [[1], [3], [5], [7], [9]]]))
318332 assert np.array_equal(y,
319333 np.array([[10], [11]]))
334+
335+ ####
336+ #Test 2
337+ data_gen = TimeseriesGenerator(data, targets,
338+ length=5, sampling_rate=1,
339+ batch_size=4, stride=2,
340+ inter_batch_stride=5)
341+ assert len(data_gen) == 9
342+
343+ #First batch
344+ assert np.array_equal(data_gen[0][0],
345+ np.array([[ [0], [1], [2], [3], [4]],
346+ [ [2], [3], [4], [5], [6]],
347+ [ [4], [5], [6], [7], [8]],
348+ [ [6], [7], [8], [9], [10]]]))
349+ assert np.array_equal(data_gen[0][1],
350+ np.array([[5], [7], [9], [11]]))
351+
352+ #Second batch
353+ assert np.array_equal(data_gen[1][0],
354+ np.array([[ [5], [6], [7], [8], [9]],
355+ [ [7], [8], [9], [10], [11]],
356+ [ [9], [10], [11], [12], [13]],
357+ [[11], [12], [13], [14], [15]]]))
358+ assert np.array_equal(data_gen[1][1],
359+ np.array([[10], [12], [14], [16]]))
360+
361+ ####
362+ #Test 3
363+ data_gen = TimeseriesGenerator(data, targets,
364+ length=5, sampling_rate=1,
365+ batch_size=4, stride=2,
366+ inter_batch_stride=None)
367+ assert len(data_gen) == 6
368+
369+ #First batch
370+ assert np.array_equal(data_gen[0][0],
371+ np.array([[ [0], [1], [2], [3], [4]],
372+ [ [2], [3], [4], [5], [6]],
373+ [ [4], [5], [6], [7], [8]],
374+ [ [6], [7], [8], [9], [10]]]))
375+ assert np.array_equal(data_gen[0][1],
376+ np.array([[5], [7], [9], [11]]))
377+
378+ #Second batch
379+ assert np.array_equal(data_gen[1][0],
380+ np.array([[ [8], [9], [10], [11], [12]],
381+ [[10], [11], [12], [13], [14]],
382+ [[12], [13], [14], [15], [16]],
383+ [[14], [15], [16], [17], [18]]]))
384+ assert np.array_equal(data_gen[1][1],
385+ np.array([[13], [15], [17], [19]]))
320386 ```
321387 """
322388
@@ -327,6 +393,7 @@ def __init__(self, data, targets, length,
327393 end_index = None ,
328394 shuffle = False ,
329395 reverse = False ,
396+ inter_batch_stride = None ,
330397 batch_size = 128 ):
331398
332399 if len (data ) != len (targets ):
@@ -340,6 +407,7 @@ def __init__(self, data, targets, length,
340407 self .length = length
341408 self .sampling_rate = sampling_rate
342409 self .stride = stride
410+ self .inter_batch_stride = inter_batch_stride
343411 self .start_index = start_index + length
344412 if end_index is None :
345413 end_index = len (data ) - 1
@@ -355,15 +423,22 @@ def __init__(self, data, targets, length,
355423 % (self .start_index , self .end_index ))
356424
357425 def __len__ (self ):
358- return (self .end_index - self .start_index +
359- self .batch_size * self .stride ) // (self .batch_size * self .stride )
426+ if self .inter_batch_stride :
427+ return ((self .end_index - self .start_index +
428+ self .inter_batch_stride ) // self .inter_batch_stride )
429+ else :
430+ return (self .end_index - self .start_index +
431+ self .batch_size * self .stride ) // (self .batch_size * self .stride )
360432
361433 def __getitem__ (self , index ):
362434 if self .shuffle :
363435 rows = np .random .randint (
364436 self .start_index , self .end_index + 1 , size = self .batch_size )
365437 else :
366- i = self .start_index + self .batch_size * self .stride * index
438+ if self .inter_batch_stride :
439+ i = self .start_index + self .inter_batch_stride * index
440+ else :
441+ i = self .start_index + self .batch_size * self .stride * index
367442 rows = np .arange (i , min (i + self .batch_size *
368443 self .stride , self .end_index + 1 ), self .stride )
369444
@@ -403,6 +478,7 @@ def get_config(self):
403478 'length' : self .length ,
404479 'sampling_rate' : self .sampling_rate ,
405480 'stride' : self .stride ,
481+ 'inter_batch_stride' : self .inter_batch_stride ,
406482 'start_index' : self .start_index ,
407483 'end_index' : self .end_index ,
408484 'shuffle' : self .shuffle ,
0 commit comments