diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py
index 9c6d6da4..977cbbf9 100644
--- a/keras_preprocessing/sequence.py
+++ b/keras_preprocessing/sequence.py
@@ -8,10 +8,12 @@
 import numpy as np
 import random
 from six.moves import range
+import warnings  # used by the `gap` check in TimeseriesGenerator below
+from math import ceil
 
 from . import get_keras_submodule
 
 keras_utils = get_keras_submodule('utils')
+from keras.utils.data_utils import Sequence
 
 
 def pad_sequences(sequences, maxlen=None, dtype='int32',
@@ -213,7 +215,7 @@ def skipgrams(sequence, vocabulary_size,
         random.shuffle(words)
 
     couples += [[words[i % len(words)],
-                 random.randint(1, vocabulary_size - 1)]
+                random.randint(1, vocabulary_size - 1)]
                 for i in range(num_negative_samples)]
     if categorical:
         labels += [[1, 0]] * num_negative_samples
@@ -250,121 +252,245 @@ def _remove_long_seq(maxlen, seq, label):
     return new_seq, new_label
 
 
-class TimeseriesGenerator(keras_utils.Sequence):
+class TimeseriesGenerator(Sequence):
     """Utility class for generating batches of temporal data.
-
     This class takes in a sequence of data-points gathered at
     equal intervals, along with time series parameters such as
     stride, length of history, etc., to produce batches for
     training/validation.
-
     # Arguments
         data: Indexable generator (such as list or Numpy array)
             containing consecutive data points (timesteps).
-            The data should be at 2D, and axis 0 is expected
-            to be the time dimension.
+            The data should be convertible into a numpy array;
+            if 2D or more, axis 0 is expected to be the time dimension.
         targets: Targets corresponding to timesteps in `data`.
-            It should have same length as `data`.
-        length: Length of the output sequences (in number of timesteps).
+            It should have at least the same length as `data`.
+        length: Length of the output sub-sequence before sampling
+            (deprecated, use `hlength` instead).
         sampling_rate: Period between successive individual timesteps
-            within sequences. For rate `r`, timesteps
-            `data[i]`, `data[i-r]`, ... `data[i - length]`
-            are used for create a sample sequence.
+            within sequences; `length` has to be a multiple of `sampling_rate`.
         stride: Period between successive output sequences.
            For stride `s`, consecutive output samples would
           be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
-        start_index: Data points earlier than `start_index` will not be used
-            in the output sequences. This is useful to reserve part of the
-            data for test or validation.
-        end_index: Data points later than `end_index` will not be used
-            in the output sequences. This is useful to reserve part of the
-            data for test or validation.
+        start_index, end_index: Data points earlier than `start_index`
+            or later than `end_index` will not be used in the output sequences.
+            This is useful to reserve part of the data for test or validation.
         shuffle: Whether to shuffle output samples,
            or instead draw them in chronological order.
-        reverse: Boolean: if `true`, timesteps in each output sample will be
+        reverse: Boolean: if `True`, timesteps in each output sample will be
           in reverse chronological order.
-        batch_size: Number of timeseries samples in each batch
-            (except maybe the last one).
+        batch_size: Number of timeseries samples in each batch.
+        hlength: Effective "history" length of the output sub-sequences after
+            sampling (in number of timesteps).
+        gap: Prediction gap, i.e. the number of timesteps ahead
+            (usually zero, or the same as `sampling_rate`).
+            `x=data[i - (hlength-1)*sampling_rate - gap:i-gap+1:sampling_rate]`
+            and `y=targets[i]`
+            are used respectively as sample sequence `x` and target value `y`.
+        target_seq: Boolean: if `True`, produces full shifted sequence targets:
+            If `target_seq` is set, for sampling rate `r`, timesteps
+            `data[i - (hlength-1)*r - gap]`, ..., `data[i-r-gap]`, `data[i-gap]`
+            and
+            `targets[i - (hlength-1)*r]`, ..., `targets[i-r]`, `targets[i]`
+            are used respectively as sample sequence `x` and target sequence `y`.
+        dtype: Force sample/target dtype (default is None).
+        stateful: Helper to check if parameters are valid for stateful
+            learning (experimental).
+
     # Returns
-        A [Sequence](/utils/#sequence) instance.
+        A [Sequence](/utils/#sequence) instance of tuples `(x, y)`,
+        where `x` is a numpy array of shape `(batch_size, hlength, ...)`
+        and `y` is a numpy array of shape `(batch_size, ...)` if `target_seq`
+        is `False`, or `(batch_size, hlength, ...)` if `target_seq` is `True`.
+        If not specified, the output dtype is inferred from the data dtype.
 
     # Examples
-
     ```python
     from keras.preprocessing.sequence import TimeseriesGenerator
     import numpy as np
-
-    data = np.array([[i] for i in range(50)])
-    targets = np.array([[i] for i in range(50)])
-
-    data_gen = TimeseriesGenerator(data, targets,
-                                   length=10, sampling_rate=2,
-                                   batch_size=2)
-    assert len(data_gen) == 20
-
-    batch_0 = data_gen[0]
-    x, y = batch_0
-    assert np.array_equal(x,
-                          np.array([[[0], [2], [4], [6], [8]],
-                                    [[1], [3], [5], [7], [9]]]))
-    assert np.array_equal(y,
-                          np.array([[10], [11]]))
+    txt = bytearray("Keras is simple.", 'utf-8')
+    data_gen = TimeseriesGenerator(txt, txt, hlength=10, batch_size=1, gap=1)
+
+    for i in range(len(data_gen)):
+        print(data_gen[i][0].tostring(), "->'%s'" % data_gen[i][1].tostring())
+
+    assert data_gen[-1][0].shape == (1, 10) and data_gen[-1][1].shape == (1,)
+    assert data_gen[-1][0].tostring() == b" is simple"
+    assert data_gen[-1][1].tostring() == b"."
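+
+    # Illustrative check of the slicing rule described above: with
+    # hlength=10, sampling_rate=1 and gap=1, the sample ending at index i
+    # is data[i - 10:i] and its target is targets[i].
+    i = len(txt) - 1
+    assert bytes(txt[i - 10:i]) == data_gen[-1][0].tostring()
+    assert bytes(txt[i:i + 1]) == data_gen[-1][1].tostring()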
+
+    t = np.linspace(0, 20*np.pi, num=1000)  # time
+    x = np.sin(np.cos(3*t))    # input signal
+    y = np.sin(np.cos(6*t+4))  # output signal
+
+    # define recurrent model
+    from keras.models import Model
+    from keras.layers import Input, SimpleRNN, LSTM, GRU, Dense
+
+    inputs = Input(batch_shape=(None, None, 1))
+    l = SimpleRNN(100, return_sequences=True)(inputs)
+    l = Dense(100, activation='tanh')(l)
+    preds = Dense(1, activation='linear')(l)
+    model = Model(inputs=inputs, outputs=preds)
+    model.compile(loss='mean_squared_error', optimizer='Nadam')
+
+    # fit model to sequence
+    xx = np.expand_dims(x, axis=-1)
+    g = TimeseriesGenerator(xx, y, hlength=100, target_seq=True, shuffle=True)
+    model.fit_generator(g, steps_per_epoch=len(g), epochs=20, shuffle=True)
+
+    # plot prediction
+    x2 = np.reshape(x, (1, x.shape[0], 1))
+    z = model.predict(x2)
+
+    import matplotlib.pyplot as plt
+    plt.figure(figsize=(12, 12))
+    plt.title('Phase representation')
+    plt.plot(x, y.flatten(), color='black')
+    plt.plot(x, z.flatten(), dashes=[8, 1], label='prediction', color='orange')
+    plt.xlabel('input')
+    plt.ylabel('output')
+    plt.grid()
+    plt.show()
+    ```
     """
 
-    def __init__(self, data, targets, length,
+    def __init__(self, data, targets, length=None,
                  sampling_rate=1,
                  stride=1,
-                 start_index=0,
-                 end_index=None,
+                 start_index=0, end_index=None,
                  shuffle=False,
                  reverse=False,
-                 batch_size=128):
-        self.data = data
-        self.targets = targets
-        self.length = length
+                 batch_size=128,
+                 hlength=None,
+                 target_seq=False,
+                 gap=0,
+                 dtype=None,
+                 stateful=False):
+
+        # Sanity checks
+
+        if sampling_rate <= 0:
+            raise ValueError('`sampling_rate` must be strictly positive.')
+        if stride <= 0:
+            raise ValueError('`stride` must be strictly positive.')
+        if batch_size <= 0:
+            raise ValueError('`batch_size` must be strictly positive.')
+        if len(data) > len(targets):
+            raise ValueError('`targets` has to be at least as long as `data`.')
+
+        if hlength is None:
+            if length is None:
+                raise ValueError('Either `length` or `hlength` must be given.')
+            if length % sampling_rate != 0:
+                raise ValueError(
+                    "`length` has to be a multiple of `sampling_rate`."
+                    " For instance, `length=%i` would do." % (2 * sampling_rate))
+            hlength = length // sampling_rate
+
+        if gap % sampling_rate != 0:
+            warnings.warn(
+                "Unless you know what you do, `gap` should be zero or"
+                " a multiple of `sampling_rate`.", UserWarning)
+
+        self.hlength = hlength
+        assert self.hlength > 0
+
+        self.data = np.asarray(data)
+        self.targets = np.asarray(targets)
+
+        # FIXME: targets must be 2D for sequences output
+        if target_seq and len(self.targets.shape) < 2:
+            self.targets = np.expand_dims(self.targets, axis=-1)
+
+        if dtype is None:
+            self.data_type = self.data.dtype
+            self.targets_type = self.targets.dtype
+        else:
+            self.data_type = dtype
+            self.targets_type = dtype
+
+        # Check if parameters are stateful-compatible
+        if stateful:
+            if shuffle:
+                raise ValueError('Do not shuffle for stateful learning.')
+            if self.hlength % batch_size != 0:
+                raise ValueError("For stateful learning, `hlength` has to be"
+                                 " a multiple of `batch_size`."
+                                 " For instance, `hlength=%i` would do."
+                                 % (3 * batch_size))
+            if stride != (self.hlength // batch_size) * sampling_rate:
+                raise ValueError(
+                    '`stride=%i`, for these parameters set `stride=%i`.'
+                    % (stride, (hlength // batch_size) * sampling_rate))
+
         self.sampling_rate = sampling_rate
+        self.batch_size = batch_size
+        assert stride > 0
         self.stride = stride
-        self.start_index = start_index + length
+        self.gap = gap
+
+        sliding_win_size = (self.hlength - 1) * sampling_rate + gap
+        self.start_index = start_index + sliding_win_size
         if end_index is None:
-            end_index = len(data) - 1
+            end_index = len(data)
+        assert end_index <= len(data)
         self.end_index = end_index
-        self.shuffle = shuffle
         self.reverse = reverse
-        self.batch_size = batch_size
+        self.target_seq = target_seq
+
+        self.len = int(ceil(float(self.end_index - self.start_index) /
+                            (self.batch_size * self.stride)))
+        if self.len <= 0:
+            err = ("This configuration gives no output, try with a longer"
+                   " input sequence or different parameters.")
+            raise ValueError(err)
+
+        assert self.len > 0
 
-        if self.start_index > self.end_index:
-            raise ValueError('`start_index+length=%i > end_index=%i` '
-                             'is disallowed, as no part of the sequence '
-                             'would be left to be used as current step.'
-                             % (self.start_index, self.end_index))
+        self.perm = np.arange(self.start_index, self.end_index)
+        if shuffle:
+            np.random.shuffle(self.perm)
 
     def __len__(self):
-        return (self.end_index - self.start_index +
-                self.batch_size * self.stride) // (self.batch_size * self.stride)
+        return self.len
 
     def _empty_batch(self, num_rows):
-        samples_shape = [num_rows, self.length // self.sampling_rate]
+        samples_shape = [num_rows, self.hlength]
         samples_shape.extend(self.data.shape[1:])
-        targets_shape = [num_rows]
+        if self.target_seq:
+            targets_shape = [num_rows, self.hlength]
+        else:
+            targets_shape = [num_rows]
         targets_shape.extend(self.targets.shape[1:])
-        return np.empty(samples_shape), np.empty(targets_shape)
+
+        return np.empty(samples_shape, dtype=self.data_type), np.empty(
+            targets_shape, dtype=self.targets_type)
 
     def __getitem__(self, index):
-        if self.shuffle:
-            rows = np.random.randint(
-                self.start_index, self.end_index + 1, size=self.batch_size)
-        else:
-            i = self.start_index + self.batch_size * self.stride * index
-            rows = np.arange(i, min(i + self.batch_size *
-                                    self.stride, self.end_index + 1), self.stride)
+        while index < 0:
+            index += self.len
+        assert index < self.len
+        batch_start = self.batch_size * self.stride * index
+        rows = np.arange(batch_start, min(batch_start + self.batch_size *
+                                          self.stride,
+                                          self.end_index - self.start_index),
+                         self.stride)
+        rows = self.perm[rows]
 
         samples, targets = self._empty_batch(len(rows))
         for j, row in enumerate(rows):
-            indices = range(rows[j] - self.length, rows[j], self.sampling_rate)
+            indices = range(rows[j] - self.gap -
+                            (self.hlength - 1) * self.sampling_rate,
+                            rows[j] - self.gap + 1, self.sampling_rate)
             samples[j] = self.data[indices]
-            targets[j] = self.targets[rows[j]]
+            if self.target_seq:
+                shifted_indices = range(rows[j] - (self.hlength - 1) *
+                                        self.sampling_rate,
+                                        rows[j] + 1, self.sampling_rate)
+                targets[j] = self.targets[shifted_indices]
+            else:
+                targets[j] = self.targets[rows[j]]
 
         if self.reverse:
             return samples[:, ::-1, ...], targets
         return samples, targets
diff --git a/tests/sequence_test.py b/tests/sequence_test.py
index 5400f8c9..c7ad461b 100644
--- a/tests/sequence_test.py
+++ b/tests/sequence_test.py
@@ -103,116 +103,158 @@ def test_remove_long_seq():
     assert new_label == ['a']
 
 
+from keras_preprocessing.sequence import TimeseriesGenerator
+
+
 def test_TimeseriesGenerator():
     data = np.array([[i] for i in range(50)])
     targets = np.array([[i] for i in range(50)])
 
-    data_gen = 
sequence.TimeseriesGenerator(data, targets, - length=10, - sampling_rate=2, - batch_size=2) + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, + batch_size=2, gap=2) assert len(data_gen) == 20 - assert (np.allclose(data_gen[0][0], - np.array([[[0], [2], [4], [6], [8]], - [[1], [3], [5], [7], [9]]]))) - assert (np.allclose(data_gen[0][1], - np.array([[10], [11]]))) - assert (np.allclose(data_gen[1][0], - np.array([[[2], [4], [6], [8], [10]], - [[3], [5], [7], [9], [11]]]))) - assert (np.allclose(data_gen[1][1], - np.array([[12], [13]]))) - - data_gen = sequence.TimeseriesGenerator(data, targets, - length=10, - sampling_rate=2, - reverse=True, - batch_size=2) + assert (np.array_equal(data_gen[0][0], + np.array([[[0], [2], [4], [6], [8]], + [[1], [3], [5], [7], [9]]]))) + assert (np.array_equal(data_gen[0][1], + np.array([[10], [11]]))) + assert (np.array_equal(data_gen[1][0], + np.array([[[2], [4], [6], [8], [10]], + [[3], [5], [7], [9], [11]]]))) + assert (np.array_equal(data_gen[1][1], + np.array([[12], [13]]))) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, reverse=True, + batch_size=2, gap=2) assert len(data_gen) == 20 - assert (np.allclose(data_gen[0][0], - np.array([[[8], [6], [4], [2], [0]], - [[9], [7], [5], [3], [1]]]))) - assert (np.allclose(data_gen[0][1], - np.array([[10], [11]]))) - - data_gen = sequence.TimeseriesGenerator(data, targets, - length=10, - sampling_rate=2, - shuffle=True, - batch_size=1) + assert (np.array_equal(data_gen[0][0], + np.array([[[8], [6], [4], [2], [0]], + [[9], [7], [5], [3], [1]]]))) + assert (np.array_equal(data_gen[0][1], + np.array([[10], [11]]))) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, shuffle=True, + batch_size=1, gap=2) batch = data_gen[0] r = batch[1][0][0] - assert (np.allclose(batch[0], - np.array([[[r - 10], - [r - 8], - [r - 6], - [r - 4], - [r - 2]]]))) - assert (np.allclose(batch[1], np.array([[r], ]))) - - data_gen = sequence.TimeseriesGenerator(data, targets, - length=10, - sampling_rate=2, - stride=2, - batch_size=2) + assert (np.array_equal(batch[0], + np.array([[[r - 10], + [r - 8], + [r - 6], + [r - 4], + [r - 2]]]))) + assert (np.array_equal(batch[1], np.array([[r], ]))) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, stride=2, + batch_size=2, gap=2) assert len(data_gen) == 10 - assert (np.allclose(data_gen[1][0], - np.array([[[4], [6], [8], [10], [12]], - [[6], [8], [10], [12], [14]]]))) - assert (np.allclose(data_gen[1][1], - np.array([[14], [16]]))) - - data_gen = sequence.TimeseriesGenerator(data, targets, - length=10, - sampling_rate=2, - start_index=10, - end_index=30, - batch_size=2) - assert len(data_gen) == 6 - assert (np.allclose(data_gen[0][0], - np.array([[[10], [12], [14], [16], [18]], - [[11], [13], [15], [17], [19]]]))) - assert (np.allclose(data_gen[0][1], - np.array([[20], [21]]))) + assert (np.array_equal(data_gen[1][0], + np.array([[[4], [6], [8], [10], [12]], + [[6], [8], [10], [12], [14]]]))) + assert (np.array_equal(data_gen[1][1], + np.array([[14], [16]]))) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, + start_index=10, end_index=30, + batch_size=2, gap=2) + assert len(data_gen) == 5 + assert (np.array_equal(data_gen[0][0], + np.array([[[10], [12], [14], [16], [18]], + [[11], [13], [15], [17], [19]]]))) + assert (np.array_equal(data_gen[0][1], + np.array([[20], [21]]))) data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)]) 
     targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)])
-    data_gen = sequence.TimeseriesGenerator(data, targets,
-                                            length=10,
-                                            sampling_rate=2,
-                                            start_index=10,
-                                            end_index=30,
-                                            batch_size=2)
-    assert len(data_gen) == 6
-    assert np.allclose(data_gen[0][0], np.array(
+    data_gen = TimeseriesGenerator(data, targets,
+                                   length=10, sampling_rate=2,
+                                   start_index=10, end_index=30,
+                                   batch_size=2, gap=2)
+    assert len(data_gen) == 5
+    assert np.array_equal(data_gen[0][0], np.array(
         [np.array(data[10:19:2]), np.array(data[11:20:2])]))
-    assert (np.allclose(data_gen[0][1],
-                        np.array([targets[20], targets[21]])))
+    assert (np.array_equal(data_gen[0][1],
+                           np.array([targets[20], targets[21]])))
+
+
+def test_TimeseriesGenerator_exceptions():
+
+    data = np.array([[i] for i in range(50)])
+
+    with assert_raises(ValueError) as context:
+        TimeseriesGenerator(data, data, length=50, stride=0)
+    error = str(context.exception)
+    print(error)
+    assert 'must be strictly positive.' in error
+
+    with assert_raises(ValueError) as context:
+        TimeseriesGenerator(data, data, length=50, sampling_rate=0)
+    error = str(context.exception)
+    print(error)
+    assert 'must be strictly positive.' in error
+
+    with assert_raises(ValueError) as context:
+        TimeseriesGenerator(data, data, length=50, batch_size=0)
+    error = str(context.exception)
+    print(error)
+    assert 'must be strictly positive.' in error
 
     with assert_raises(ValueError) as context:
-        sequence.TimeseriesGenerator(data, targets, length=50)
+        TimeseriesGenerator(data, data, length=50, start_index=50)
     error = str(context.exception)
-    assert '`start_index+length=50 > end_index=49` is disallowed' in error
+    print(error)
+    assert 'This configuration gives no output' in error
+
+    with assert_raises(ValueError) as context:
+        TimeseriesGenerator(data, data, length=50, sampling_rate=51)
+    error = str(context.exception)
+    print(error)
+    assert ("`length` has to be a multiple of `sampling_rate`."
+            " For instance, `length=102` would do.") in error
+
+    with assert_raises(ValueError) as context:
+        TimeseriesGenerator(data, data, length=10, sampling_rate=3)
+    error = str(context.exception)
+    print(error)
+    assert ("`length` has to be a multiple of `sampling_rate`."
+            " For instance, `length=6` would do.")
in error -def test_TimeSeriesGenerator_doesnt_miss_any_sample(): +def test_TimeSeriesGenerator_doesnt_miss_any_sample1(): x = np.array([[i] for i in range(10)]) - for length in range(3, 10): - g = sequence.TimeseriesGenerator(x, x, - length=length, - batch_size=1) - expected = max(0, len(x) - length) - actual = len(g) + for gap in range(10): + for length in range(1, 11 - gap): + + expected = len(x) - length + 1 - gap + + if expected > 0: + g = TimeseriesGenerator(x, x, + length=length, + batch_size=1, gap=gap) - assert expected == actual + actual = len(g) + assert expected == actual - if len(g) > 0: - # All elements in range(length, 10) should be used as current step - expected = np.arange(length, 10).reshape(-1, 1) + x = np.array([i for i in range(7)]) - y = np.concatenate([g[ix][1] for ix in range(len(g))], axis=0) - assert_allclose(y, expected) + g = TimeseriesGenerator(x, x, hlength=3, batch_size=2) + + expected_len = ceil((len(x) - g.hlength + 1.0) / g.batch_size) + print('gap: %i, hlength: %i, expected-len:%i, len: %i' % + (g.gap, g.hlength, expected_len, g.len)) + # for i in range(len(g)): + # print(i, g[i]) + + assert len(g) == expected_len + + +def test_TimeSeriesGenerator_doesnt_miss_any_sample2(): x = np.array([[i] for i in range(23)]) @@ -225,25 +267,25 @@ def test_TimeSeriesGenerator_doesnt_miss_any_sample(): lengths, batch_sizes, shuffles): - g = sequence.TimeseriesGenerator(x, x, - length=length, - sampling_rate=1, - stride=stride, - start_index=0, - end_index=None, - shuffle=shuffle, - reverse=False, - batch_size=batch_size) - if shuffle: - # all batches have the same size when shuffle is True. - expected_sequences = ceil( - (23 - length) / float(batch_size * stride)) * batch_size - else: - # last batch will be different if `(samples - length) / stride` - # is not a multiple of `batch_size`. - expected_sequences = ceil((23 - length) / float(stride)) + g = TimeseriesGenerator(x, x, + length=length, + sampling_rate=1, + stride=stride, + start_index=0, + end_index=None, + shuffle=shuffle, + reverse=False, + batch_size=batch_size) + + # last batch will be different if `(samples - length) / stride` + # is not a multiple of `batch_size`. 
+ expected_sequences = int(ceil((len(x) - length + 1.0) / stride)) expected_batches = ceil(expected_sequences / float(batch_size)) + print('gap: %i, hlength: %i, expected-len:%i, len: %i' % + (g.gap, g.hlength, expected_batches, g.len)) + for i in range(len(g)): + print(i, g[i]) y = [g[ix][1] for ix in range(len(g))] @@ -254,5 +296,155 @@ def test_TimeSeriesGenerator_doesnt_miss_any_sample(): assert expected_batches == actual_batches +def test_TimeseriesGenerator_types(): + + print("** test 0 (float types)") + + data = np.array([[i] for i in range(50)], dtype=np.float) + targets = np.array([[float(i)] for i in range(50)]) + + data_gen = TimeseriesGenerator(data, targets, + hlength=5, sampling_rate=2, gap=2, + batch_size=2, shuffle=False) + x, y = data_gen[0] + + assert np.allclose(x, np.array([[[0], [2], [4], [6], [8]], + [[1], [3], [5], [7], [9]]])) + assert np.allclose(y, np.array([[10], [11]])) + + print("** test 1 (auto types)") + + data = np.array([[i] for i in range(50)], dtype=np.float) + targets = np.array([[i] for i in range(50)], dtype=np.float) + + data_gen = TimeseriesGenerator(data, targets, + hlength=5, sampling_rate=2, gap=2, + batch_size=2, shuffle=False) + x, y = data_gen[0] + assert len(data_gen) == 20 + assert np.array_equal(x, np.array([[[0], [2], [4], [6], [8]], + [[1], [3], [5], [7], [9]]])) + assert np.array_equal(y, np.array([[10], [11]])) + + x, y = data_gen[-1] + + assert np.array_equal(x, np.array([[[38], [40], [42], [44], [46]], + [[39], [41], [43], [45], [47]]])) + assert np.array_equal(y, np.array([[48], [49]])) + + print("** test 2 (batch_size=4)") + data_gen = TimeseriesGenerator( + data, targets, hlength=10, batch_size=4, gap=1) + assert len(data_gen) == 10 + x, y = data_gen[0] + assert np.array_equal(x[1], np.array( + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]])) + assert np.array_equal(y, np.array([[10], [11], [12], [13]])) + + data_gen = TimeseriesGenerator( + data, targets, hlength=10, reverse=True, batch_size=2) + x, y = data_gen[0] + assert np.array_equal(x[1, 0], np.array([10])) + + print("** test 3 (when sampling_rate is not a multiple of hlength)") + data_gen = TimeseriesGenerator( + data, targets, hlength=10, sampling_rate=3, batch_size=2) + + # for i in range(len(data_gen)): + # print(i,data_gen[i]) + + assert len(data_gen) == 12 + + print("** test 4 (stateful)") + data_gen = TimeseriesGenerator( + data, targets, hlength=10, sampling_rate=2, batch_size=5, stateful=True, + gap=2, stride=4) + + +def test_TimeseriesGenerator_on_text(): + + txt = bytearray("Keras is simple.", 'utf-8') + data_gen = TimeseriesGenerator(txt, txt, hlength=10, batch_size=1, gap=1) + + # for i in range(len(data_gen)): + # print(data_gen[i][0].tostring(), "->'%s'" % data_gen[i][1].tostring()) + + assert data_gen[-1][0].shape == (1, 10) and data_gen[-1][1].shape == (1,) + assert data_gen[-1][0].tostring() == b" is simple" + assert data_gen[-1][1].tostring() == b"." 
+ + data_gen = TimeseriesGenerator( + txt, txt, hlength=10, target_seq=True, batch_size=1, gap=1) + + assert data_gen[-1][0].shape == (1, + 10) and data_gen[-1][1].shape == (1, 10, 1) + # for i in range(len(data_gen)): + # print(data_gen[i][0].tostring(), "->'%s'" % data_gen[i][1].tostring()) + + assert data_gen[0][1].tostring() == b"eras is si" + + +def test_TimeseriesGenerator_previous_tests(): + + data = np.array([[i] for i in range(50)]) + + data_gen = TimeseriesGenerator(data, data, + length=10, sampling_rate=2, reverse=True, + batch_size=2, gap=2) + assert len(data_gen) == 20 + assert (np.allclose(data_gen[0][0], + np.array([[[8], [6], [4], [2], [0]], + [[9], [7], [5], [3], [1]]]))) + assert (np.allclose(data_gen[0][1], + np.array([[10], [11]]))) + + data_gen = TimeseriesGenerator(data, data, + length=10, sampling_rate=2, shuffle=True, + batch_size=1, gap=2) + batch = data_gen[0] + r = batch[1][0][0] + assert (np.allclose(batch[0], + np.array([[[r - 10], + [r - 8], + [r - 6], + [r - 4], + [r - 2]]]))) + assert (np.allclose(batch[1], np.array([[r], ]))) + + data_gen = TimeseriesGenerator(data, data, + length=10, sampling_rate=2, stride=2, + batch_size=2, gap=2) + assert len(data_gen) == 10 + assert (np.allclose(data_gen[1][0], + np.array([[[4], [6], [8], [10], [12]], + [[6], [8], [10], [12], [14]]]))) + assert (np.allclose(data_gen[1][1], + np.array([[14], [16]]))) + + data_gen = TimeseriesGenerator(data, data, + length=10, sampling_rate=2, + start_index=10, end_index=30, + batch_size=2, gap=2) + assert len(data_gen) == 5 + assert (np.allclose(data_gen[0][0], + np.array([[[10], [12], [14], [16], [18]], + [[11], [13], [15], [17], [19]]]))) + assert (np.allclose(data_gen[0][1], + np.array([[20], [21]]))) + + data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)]) + targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)]) + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, + start_index=10, end_index=30, + batch_size=2, gap=2) + + assert len(data_gen) == 5 + assert np.allclose(data_gen[0][0], np.array( + [np.array(data[10:19:2]), np.array(data[11:20:2])])) + assert (np.allclose(data_gen[0][1], + np.array([targets[20], targets[21]]))) + + if __name__ == '__main__': pytest.main([__file__])
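
A short usage sketch (not part of the patch), assuming the patch above is applied: it shows how the new `hlength`, `sampling_rate` and `gap` parameters map onto plain numpy slicing, mirroring the `gap=2` cases in the updated tests.

```python
# Sketch only: assumes keras_preprocessing with this patch applied.
import numpy as np
from keras_preprocessing.sequence import TimeseriesGenerator

data = np.arange(50).reshape(-1, 1)     # axis 0 is the time dimension
targets = np.arange(50).reshape(-1, 1)

# 5-step history sampled every 2 timesteps, predicting 2 steps ahead.
gen = TimeseriesGenerator(data, targets,
                          hlength=5, sampling_rate=2, gap=2,
                          batch_size=2)

x, y = gen[0]
# The first sample ends at i = (hlength - 1) * sampling_rate + gap = 10, so
# x[0] = data[i - gap - 8:i - gap + 1:2] = data[0:9:2] and y[0] = targets[10].
assert np.array_equal(x[0], data[0:9:2])
assert np.array_equal(y[0], targets[10])
```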