diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index 0e030022..3f95aa4a 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -278,8 +278,16 @@ class TimeseriesGenerator(object): `data[i]`, `data[i-r]`, ... `data[i - length]` are used for create a sample sequence. stride: Period between successive output sequences. - For stride `s`, consecutive output samples would - be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. + For stride `Stride`, consecutive output samples would + be centered around `data[i]`, `data[i+Stride]`, `data[i+2*Stride]`, etc. + If ``inter_batch_stride`` is None, then `Stride` carries over + between batches: the first sample of a batch is a stride ahead of the + last sample of the previous batch. The formula for the collection of + samples in a batch when ``inter_batch_stride = None`` is + ```[[[data[batch_size*Stride*b + s*Stride + l] + for l in range(length)] for s in range(batch_size)] + for b in range(num_batches)] + ``` start_index: Data points earlier than `start_index` will not be used in the output sequences. This is useful to reserve part of the data for test or validation. @@ -292,6 +300,10 @@ class TimeseriesGenerator(object): in reverse chronological order. batch_size: Number of timeseries samples in each batch (except maybe the last one). + inter_batch_stride: If not None, grants explicit control w.r.t. the + inter-batch stride relationship -- instead of the first sample + of a batch being a stride ahead of the last sample in the previous + batch. # Returns A [Sequence](/utils/#sequence) instance. @@ -305,6 +317,8 @@ class TimeseriesGenerator(object): data = np.array([[i] for i in range(50)]) targets = np.array([[i] for i in range(50)]) + #### + #Test 1 data_gen = TimeseriesGenerator(data, targets, length=10, sampling_rate=2, batch_size=2) @@ -317,6 +331,58 @@ class TimeseriesGenerator(object): [[1], [3], [5], [7], [9]]])) assert np.array_equal(y, np.array([[10], [11]])) + + #### + #Test 2 + data_gen = TimeseriesGenerator(data, targets, + length=5, sampling_rate=1, + batch_size=4, stride=2, + inter_batch_stride=5) + assert len(data_gen) == 9 + + #First batch + assert np.array_equal(data_gen[0][0], + np.array([[ [0], [1], [2], [3], [4]], + [ [2], [3], [4], [5], [6]], + [ [4], [5], [6], [7], [8]], + [ [6], [7], [8], [9], [10]]])) + assert np.array_equal(data_gen[0][1], + np.array([[5], [7], [9], [11]])) + + #Second batch + assert np.array_equal(data_gen[1][0], + np.array([[ [5], [6], [7], [8], [9]], + [ [7], [8], [9], [10], [11]], + [ [9], [10], [11], [12], [13]], + [[11], [12], [13], [14], [15]]])) + assert np.array_equal(data_gen[1][1], + np.array([[10], [12], [14], [16]])) + + #### + #Test 3 + data_gen = TimeseriesGenerator(data, targets, + length=5, sampling_rate=1, + batch_size=4, stride=2, + inter_batch_stride=None) + assert len(data_gen) == 6 + + #First batch + assert np.array_equal(data_gen[0][0], + np.array([[ [0], [1], [2], [3], [4]], + [ [2], [3], [4], [5], [6]], + [ [4], [5], [6], [7], [8]], + [ [6], [7], [8], [9], [10]]])) + assert np.array_equal(data_gen[0][1], + np.array([[5], [7], [9], [11]])) + + #Second batch + assert np.array_equal(data_gen[1][0], + np.array([[ [8], [9], [10], [11], [12]], + [[10], [11], [12], [13], [14]], + [[12], [13], [14], [15], [16]], + [[14], [15], [16], [17], [18]]])) + assert np.array_equal(data_gen[1][1], + np.array([[13], [15], [17], [19]])) ``` """ @@ -327,7 +393,8 @@ def __init__(self, data, targets, length, end_index=None, shuffle=False, reverse=False, - batch_size=128): + batch_size=128, + inter_batch_stride=None): if len(data) != len(targets): raise ValueError('Data and targets have to be' + @@ -340,6 +407,7 @@ def __init__(self, data, targets, length, self.length = length self.sampling_rate = sampling_rate self.stride = stride + self.inter_batch_stride = inter_batch_stride self.start_index = start_index + length if end_index is None: end_index = len(data) - 1 @@ -355,15 +423,22 @@ def __init__(self, data, targets, length, % (self.start_index, self.end_index)) def __len__(self): - return (self.end_index - self.start_index + - self.batch_size * self.stride) // (self.batch_size * self.stride) + if self.inter_batch_stride: + return ((self.end_index - self.start_index + + self.inter_batch_stride) // self.inter_batch_stride) + else: + return (self.end_index - self.start_index + + self.batch_size * self.stride) // (self.batch_size * self.stride) def __getitem__(self, index): if self.shuffle: rows = np.random.randint( self.start_index, self.end_index + 1, size=self.batch_size) else: - i = self.start_index + self.batch_size * self.stride * index + if self.inter_batch_stride: + i = self.start_index + self.inter_batch_stride*index + else: + i = self.start_index + self.batch_size * self.stride * index rows = np.arange(i, min(i + self.batch_size * self.stride, self.end_index + 1), self.stride) @@ -407,7 +482,8 @@ def get_config(self): 'end_index': self.end_index, 'shuffle': self.shuffle, 'reverse': self.reverse, - 'batch_size': self.batch_size + 'batch_size': self.batch_size, + 'inter_batch_stride': self.inter_batch_stride } def to_json(self, **kwargs): diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 36fe1d74..b51a62b3 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -213,6 +213,28 @@ def test_TimeseriesGenerator(): assert (np.allclose(data_gen[0][1], np.array([[20], [21]]))) + data_gen = sequence.TimeseriesGenerator(data, targets, + length=5, sampling_rate=1, + batch_size=4, stride=2, + inter_batch_stride=5) + assert len(data_gen) == 9 + + assert np.allclose(data_gen[0][0], + np.array([[[0], [1], [2], [3], [4]], + [[2], [3], [4], [5], [6]], + [[4], [5], [6], [7], [8]], + [[6], [7], [8], [9], [10]]])) + assert np.allclose(data_gen[0][1], + np.array([[5], [7], [9], [11]])) + + assert np.allclose(data_gen[1][0], + np.array([[[5], [6], [7], [8], [9]], + [[7], [8], [9], [10], [11]], + [[9], [10], [11], [12], [13]], + [[11], [12], [13], [14], [15]]])) + assert np.allclose(data_gen[1][1], + np.array([[10], [12], [14], [16]])) + data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)]) targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)]) data_gen = sequence.TimeseriesGenerator(data, targets,