-
Notifications
You must be signed in to change notification settings - Fork 437
New TimeseriesGenerator #7
Changes from all commits
8514b36
9a0a403
2308f06
aa573f9
5e4b35e
04b019d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,10 +8,12 @@ | |
| import numpy as np | ||
| import random | ||
| from six.moves import range | ||
| from math import ceil | ||
|
|
||
| from . import get_keras_submodule | ||
|
|
||
| keras_utils = get_keras_submodule('utils') | ||
| from keras.utils.data_utils import Sequence | ||
|
|
||
|
|
||
| def pad_sequences(sequences, maxlen=None, dtype='int32', | ||
|
|
@@ -213,7 +215,7 @@ def skipgrams(sequence, vocabulary_size, | |
| random.shuffle(words) | ||
|
|
||
| couples += [[words[i % len(words)], | ||
| random.randint(1, vocabulary_size - 1)] | ||
| random.randint(1, vocabulary_size - 1)] | ||
| for i in range(num_negative_samples)] | ||
| if categorical: | ||
| labels += [[1, 0]] * num_negative_samples | ||
|
|
@@ -250,121 +252,245 @@ def _remove_long_seq(maxlen, seq, label): | |
| return new_seq, new_label | ||
|
|
||
|
|
||
| class TimeseriesGenerator(keras_utils.Sequence): | ||
| class TimeseriesGenerator(Sequence): | ||
| """Utility class for generating batches of temporal data. | ||
|
|
||
| This class takes in a sequence of data-points gathered at | ||
| equal intervals, along with time series parameters such as | ||
| stride, length of history, etc., to produce batches for | ||
| training/validation. | ||
|
|
||
| # Arguments | ||
| data: Indexable generator (such as list or Numpy array) | ||
| containing consecutive data points (timesteps). | ||
| The data should be at 2D, and axis 0 is expected | ||
| to be the time dimension. | ||
| The data should be convertible into a 1D numpy array, | ||
| if 2D or more, axis 0 is expected to be the time dimension. | ||
| targets: Targets corresponding to timesteps in `data`. | ||
| It should have same length as `data`. | ||
| length: Length of the output sequences (in number of timesteps). | ||
| It should have at least the same length as `data`. | ||
| length: length of the output sub-sequence before sampling | ||
| (depreciated, use hlength instead). | ||
| sampling_rate: Period between successive individual timesteps | ||
| within sequences. For rate `r`, timesteps | ||
| `data[i]`, `data[i-r]`, ... `data[i - length]` | ||
| are used for create a sample sequence. | ||
| within sequences, `length` has to be a multiple of `sampling_rate`. | ||
| stride: Period between successive output sequences. | ||
| For stride `s`, consecutive output samples would | ||
| be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. | ||
| start_index: Data points earlier than `start_index` will not be used | ||
| in the output sequences. This is useful to reserve part of the | ||
| data for test or validation. | ||
| end_index: Data points later than `end_index` will not be used | ||
| in the output sequences. This is useful to reserve part of the | ||
| data for test or validation. | ||
| start_index, end_index: Data points earlier than `start_index` | ||
| or later than `end_index` will not be used in the output sequences. | ||
| This is useful to reserve part of the data for test or validation. | ||
| shuffle: Whether to shuffle output samples, | ||
| or instead draw them in chronological order. | ||
| reverse: Boolean: if `true`, timesteps in each output sample will be | ||
| reverse: Boolean: if `True`, timesteps in each output sample will be | ||
| in reverse chronological order. | ||
| batch_size: Number of timeseries samples in each batch | ||
| (except maybe the last one). | ||
| batch_size: Number of timeseries samples in each batch. | ||
| hlength: Effective "history" length of the output sub-sequences after | ||
| sampling (in number of timesteps). | ||
| gap: prediction gap, i.e. numer of timesteps ahead (usually zero, or | ||
| same as samplig_rate) | ||
| `x=data[i - (hlength-1)*sampling_rate - gap:i-gap+1:sampling_rate]` | ||
| and `y=targets[i]` | ||
| are used respectively as sample sequence `x` and target value `y`. | ||
| target_seq: Boolean: if 'True', produces full shifted sequence targets: | ||
| If target_seq is set, for sampling rate `r`, timesteps | ||
| `data[i - (hlength-1)*r - gap]`, ..., `data[i-r-gap]`, `data[i-gap]` | ||
| and | ||
| `targets[i - (hlength-1)*r]`, ..., `data[i-r]`, `data[i]` | ||
| are used respectively as sample sequence `x` and target sequence `y`. | ||
| dtype: force sample/target dtype (default is None) | ||
| stateful: helper to check if parameters are valid for stateful learning | ||
| (experimental). | ||
|
|
||
|
|
||
| # Returns | ||
| A [Sequence](/utils/#sequence) instance. | ||
| A [Sequence](/utils/#sequence) instance of tuples (x,y) | ||
| where x is a numpy array of shape (batch_size, hlength, ...) | ||
| and y is a numpy array of shape (batch_size, ...) if target_seq is `False` | ||
| or (batch_size, hlength, ...) if target_seq is `True`. | ||
| If not specified, output dtype is infered from data dtype. | ||
|
|
||
| # Examples | ||
|
|
||
| ```python | ||
| from keras.preprocessing.sequence import TimeseriesGenerator | ||
| import numpy as np | ||
|
|
||
| data = np.array([[i] for i in range(50)]) | ||
| targets = np.array([[i] for i in range(50)]) | ||
|
|
||
| data_gen = TimeseriesGenerator(data, targets, | ||
| length=10, sampling_rate=2, | ||
| batch_size=2) | ||
| assert len(data_gen) == 20 | ||
|
|
||
| batch_0 = data_gen[0] | ||
| x, y = batch_0 | ||
| assert np.array_equal(x, | ||
| np.array([[[0], [2], [4], [6], [8]], | ||
| [[1], [3], [5], [7], [9]]])) | ||
| assert np.array_equal(y, | ||
| np.array([[10], [11]])) | ||
| txt = bytearray("Keras is simple.", 'utf-8') | ||
| data_gen = TimeseriesGenerator(txt, txt, hlength=10, batch_size=1, gap=1) | ||
|
|
||
| for i in range(len(data_gen)): | ||
| print(data_gen[i][0].tostring(), "->'%s'" % data_gen[i][1].tostring()) | ||
|
|
||
| assert data_gen[-1][0].shape == (1, 10) and data_gen[-1][1].shape == (1,) | ||
| assert data_gen[-1][0].tostring() == u" is simple" | ||
| assert data_gen[-1][1].tostring() == u"." | ||
|
|
||
| t = np.linspace(0,20*np.pi, num=1000) # time | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space after "," |
||
| x = np.sin(np.cos(3*t)) # input signa | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. signal* |
||
| y = np.sin(np.cos(6*t+4)) # output signal | ||
|
|
||
| # define recurrent model | ||
| from keras.models import Model | ||
| from keras.layers import Input, SimpleRNN, LSTM, GRU,Dense | ||
|
|
||
| inputs = Input(batch_shape=(None, None, 1)) | ||
| l = SimpleRNN(100, return_sequences=True)(inputs) | ||
| l = Dense(100, activation='tanh')(l) | ||
| preds = Dense(1, activation='linear')(l) | ||
| model = Model(inputs=inputs, outputs=preds) | ||
| model.compile(loss='mean_squared_error', optimizer='Nadam') | ||
|
|
||
| # fit model to sequence | ||
| xx = np.expand_dims(x, axis=-1) | ||
| g = TimeseriesGenerator(xx, y, hlength=100, target_seq=True, shuffle=True) | ||
| model.fit_generator(g, steps_per_epoch=len(g), epochs=20, shuffle=True) | ||
|
|
||
| # plot prediction | ||
| x2 = np.reshape(x,(1,x.shape[0],1)) | ||
| z = model.predict(x2) | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| plt.figure(figsize=(12,12)) | ||
| plt.title('Phase representation') | ||
| plt.plot(x,y.flatten(), color='black') | ||
| plt.plot(x,z.flatten(), dashes=[8,1], label='prediction', color='orange') | ||
| plt.xlabel('input') | ||
| plt.ylabel('output') | ||
| plt.grid() | ||
| plt.show() | ||
|
|
||
| ``` | ||
| """ | ||
|
|
||
| def __init__(self, data, targets, length, | ||
| def __init__(self, data, targets, length=None, | ||
| sampling_rate=1, | ||
| stride=1, | ||
| start_index=0, | ||
| end_index=None, | ||
| start_index=0, end_index=None, | ||
| shuffle=False, | ||
| reverse=False, | ||
| batch_size=128): | ||
| self.data = data | ||
| self.targets = targets | ||
| self.length = length | ||
| batch_size=128, | ||
| hlength=None, | ||
| target_seq=False, | ||
| gap=0, | ||
| dtype=None, | ||
| stateful=False): | ||
|
|
||
| # Sanity check | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove stray newline |
||
| if sampling_rate <= 0: | ||
| raise ValueError('`sampling_rate` must be strictly positive.') | ||
| if stride <= 0: | ||
| raise ValueError('`stride` must be strictly positive.') | ||
| if batch_size <= 0: | ||
| raise ValueError('`batch_size` must be strictly positive.') | ||
| if len(data) > len(targets): | ||
| raise ValueError('`targets` has to be at least as long as `data`.') | ||
|
|
||
| if hlength is None: | ||
| if length % sampling_rate != 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a DeprecationWarning |
||
| raise ValueError( | ||
| "`length` has to be a multiple of `sampling_rate`." | ||
| " For instance, `length=%i` would do." % (2 * sampling_rate)) | ||
| hlength = length // sampling_rate | ||
|
|
||
| if gap % sampling_rate != 0: | ||
| warnings.warn( | ||
| "Unless you know what you do, `gap` should be zero or" | ||
| " a multiple of `sampling_rate`.", UserWarning) | ||
|
|
||
| self.hlength = hlength | ||
| assert self.hlength > 0 | ||
|
|
||
| self.data = np.asarray(data) | ||
| self.targets = np.asarray(targets) | ||
|
|
||
| # FIXME: targets must be 2D for sequences output | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a new limitation or was it always like that? |
||
| if target_seq and len(self.targets.shape) < 2: | ||
| self.targets = np.expand_dims(self.targets, axis=-1) | ||
|
|
||
| if dtype is None: | ||
| self.data_type = self.data.dtype | ||
| self.targets_type = self.targets.dtype | ||
| else: | ||
| self.data_type = dtype | ||
| self.targets_type = dtype | ||
|
|
||
| # Check if parameters are stateful-compatible | ||
| if stateful: | ||
| if shuffle: | ||
| raise ValueError('Do not shuffle for stateful learning.') | ||
| if self.hlength % batch_size != 0: | ||
| raise ValueError("For stateful learning, `hlength` has to be" | ||
| "a multiple of `batch_size`." | ||
| "For instance, `hlength=%i` would do." | ||
| % (3 * batch_size)) | ||
| if stride != (self.hlength // batch_size) * sampling_rate: | ||
| raise ValueError( | ||
| '`stride=%i`, for these parameters set `stride=%i`.' | ||
| % (stride, (hlength // batch_size) * sampling_rate)) | ||
|
|
||
| self.sampling_rate = sampling_rate | ||
| self.batch_size = batch_size | ||
| assert stride > 0 | ||
| self.stride = stride | ||
| self.start_index = start_index + length | ||
| self.gap = gap | ||
|
|
||
| sliding_win_size = (self.hlength - 1) * sampling_rate + gap | ||
| self.start_index = start_index + sliding_win_size | ||
| if end_index is None: | ||
| end_index = len(data) - 1 | ||
| end_index = len(data) | ||
| assert end_index <= len(data) | ||
| self.end_index = end_index | ||
| self.shuffle = shuffle | ||
| self.reverse = reverse | ||
| self.batch_size = batch_size | ||
| self.target_seq = target_seq | ||
|
|
||
| self.len = int(ceil(float(self.end_index - self.start_index) / | ||
| (self.batch_size * self.stride))) | ||
| if self.len <= 0: | ||
| err = "This configuration gives no output, try with a longer" | ||
| " input sequence or different parameters." | ||
| raise ValueError(err) | ||
|
|
||
| assert self.len > 0 | ||
|
|
||
| if self.start_index > self.end_index: | ||
| raise ValueError('`start_index+length=%i > end_index=%i` ' | ||
| 'is disallowed, as no part of the sequence ' | ||
| 'would be left to be used as current step.' | ||
| % (self.start_index, self.end_index)) | ||
| self.perm = np.arange(self.start_index, self.end_index) | ||
| if shuffle: | ||
| np.random.shuffle(self.perm) | ||
|
|
||
| def __len__(self): | ||
| return (self.end_index - self.start_index + | ||
| self.batch_size * self.stride) // (self.batch_size * self.stride) | ||
| return self.len | ||
|
|
||
| def _empty_batch(self, num_rows): | ||
| samples_shape = [num_rows, self.length // self.sampling_rate] | ||
| samples_shape = [num_rows, self.hlength] | ||
| samples_shape.extend(self.data.shape[1:]) | ||
| targets_shape = [num_rows] | ||
| if self.target_seq: | ||
| targets_shape = [num_rows, self.hlength] | ||
| else: | ||
| targets_shape = [num_rows] | ||
| targets_shape.extend(self.targets.shape[1:]) | ||
| return np.empty(samples_shape), np.empty(targets_shape) | ||
|
|
||
| return np.empty(samples_shape, dtype=self.data_type), np.empty( | ||
| targets_shape, dtype=self.targets_type) | ||
|
|
||
| def __getitem__(self, index): | ||
| if self.shuffle: | ||
| rows = np.random.randint( | ||
| self.start_index, self.end_index + 1, size=self.batch_size) | ||
| else: | ||
| i = self.start_index + self.batch_size * self.stride * index | ||
| rows = np.arange(i, min(i + self.batch_size * | ||
| self.stride, self.end_index + 1), self.stride) | ||
| while index < 0: | ||
| index += self.len | ||
| assert index < self.len | ||
| batch_start = self.batch_size * self.stride * index | ||
| rows = np.arange(batch_start, min(batch_start + self.batch_size * | ||
| self.stride, | ||
| self.end_index - self.start_index), | ||
| self.stride) | ||
| rows = self.perm[rows] | ||
|
|
||
| samples, targets = self._empty_batch(len(rows)) | ||
| for j, row in enumerate(rows): | ||
| indices = range(rows[j] - self.length, rows[j], self.sampling_rate) | ||
| indices = range(rows[j] - self.gap - | ||
| (self.hlength - 1) * self.sampling_rate, | ||
| rows[j] - self.gap + 1, self.sampling_rate) | ||
| samples[j] = self.data[indices] | ||
| targets[j] = self.targets[rows[j]] | ||
| if self.target_seq: | ||
| shifted_indices = range(rows[j] - (self.hlength - 1) * | ||
| self.sampling_rate, | ||
| rows[j] + 1, self.sampling_rate) | ||
| targets[j] = self.targets[shifted_indices] | ||
| else: | ||
| targets[j] = self.targets[rows[j]] | ||
| if self.reverse: | ||
| return samples[:, ::-1, ...], targets | ||
| return samples, targets | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Might make this a little easier to understand if we say "and* axis 0"