-
Notifications
You must be signed in to change notification settings - Fork 216
[ENH] Added BaseDeepForecaster in forecasting/deep_learning #2905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
af94a86
7bb161e
d2ee9ec
ab3030c
1f202db
14eb41f
d1a2aab
865ed14
30d862a
78b2f3d
d1a7fd0
b6ccd07
a39fafb
f7fd5bd
405fa80
004afed
c0ca211
bc1adba
b25059d
fdf5b6d
51f655a
f3af433
003badc
de01961
7ce3e68
50dbec6
a19c8ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Initialization for aeon forecasting deep learning module.""" | ||
|
||
__all__ = [ | ||
"BaseDeepForecaster", | ||
] | ||
|
||
from aeon.forecasting.deep_learning.base import BaseDeepForecaster |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
"""Base class module for deep learning forecasters in aeon. | ||
|
||
This module defines the `BaseDeepForecaster` class, an abstract base class for | ||
deep learning-based forecasting models within the aeon toolkit. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
__maintainer__ = [] | ||
__all__ = ["BaseDeepForecaster"] | ||
|
||
from abc import abstractmethod | ||
from typing import Any | ||
|
||
from aeon.forecasting.base import BaseForecaster | ||
|
||
|
||
class BaseDeepForecaster(BaseForecaster): | ||
"""Base class for deep learning forecasters in aeon. | ||
|
||
This class provides a foundation for deep learning-based forecasting models, | ||
handling data preprocessing, model training, and prediction with enhanced | ||
capabilities for callbacks, model saving/loading, and efficiency. | ||
|
||
Parameters | ||
---------- | ||
window : int, | ||
The window size for creating input sequences. | ||
horizon : int, default=1 | ||
Forecasting horizon, the number of steps ahead to predict. | ||
verbose : int, default=0 | ||
Verbosity mode (0, 1, or 2). | ||
lucifer4073 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
callbacks : list of tf.keras.callbacks.Callback or None, default=None | ||
List of Keras callbacks to be applied during training. | ||
axis : int, default=0 | ||
Axis along which to apply the forecaster. | ||
last_file_name : str, default="last_model" | ||
The name of the file of the last model, used for saving models. | ||
file_path : str, default="./" | ||
Directory path where models will be saved. | ||
|
||
Attributes | ||
---------- | ||
model_ : tf.keras.Model or None | ||
The fitted Keras model. | ||
history_ : tf.keras.callbacks.History or None | ||
Training history containing loss and metrics. | ||
last_window_ : np.ndarray or None | ||
lucifer4073 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The last window of data used for prediction. | ||
""" | ||
|
||
_tags = { | ||
"capability:horizon": True, | ||
"capability:exogenous": False, | ||
"algorithm_type": "deeplearning", | ||
"non_deterministic": True, | ||
"cant_pickle": True, | ||
"python_dependencies": "tensorflow", | ||
"capability:multivariate": False, | ||
} | ||
|
||
def __init__( | ||
self, | ||
window, | ||
horizon=1, | ||
verbose=0, | ||
callbacks=None, | ||
axis=0, | ||
last_file_name="last_model", | ||
file_path="./", | ||
): | ||
self.horizon = horizon | ||
self.window = window | ||
self.verbose = verbose | ||
self.callbacks = callbacks | ||
self.axis = axis | ||
self.last_file_name = last_file_name | ||
self.file_path = file_path | ||
|
||
self.model_ = None | ||
self.history_ = None | ||
self.last_window_ = None | ||
|
||
super().__init__(horizon=horizon, axis=axis) | ||
|
||
def _fit(self, y, exog=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the default fit in base just returns self. Is the intention to require it for BaseDeepForecasters? |
||
"""Fit the model.""" | ||
... | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. abstract methods should be annotated abstract |
||
def _predict(self, y, exog=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the same as the base class predict, so isnt needed I dont think |
||
"""Predict using the model.""" | ||
... | ||
|
||
def _prepare_callbacks(self): | ||
"""Prepare callbacks for training. | ||
|
||
Returns | ||
------- | ||
callbacks_list : list | ||
List of callbacks to be used during training. | ||
""" | ||
callbacks_list = [] | ||
if self.callbacks is not None: | ||
if isinstance(self.callbacks, list): | ||
callbacks_list.extend(self.callbacks) | ||
else: | ||
callbacks_list.append(self.callbacks) | ||
|
||
callbacks_list = self._get_model_checkpoint_callback( | ||
callbacks_list, self.file_path, "best_model" | ||
) | ||
return callbacks_list | ||
|
||
def _get_model_checkpoint_callback(self, callbacks, file_path, file_name): | ||
"""Add model checkpoint callback to save the best model. | ||
|
||
Parameters | ||
---------- | ||
callbacks : list | ||
Existing list of callbacks. | ||
file_path : str | ||
Directory path where the model will be saved. | ||
file_name : str | ||
Name of the model file. | ||
|
||
Returns | ||
------- | ||
callbacks : list | ||
Updated list of callbacks including ModelCheckpoint. | ||
""" | ||
import tensorflow as tf | ||
|
||
model_checkpoint_ = tf.keras.callbacks.ModelCheckpoint( | ||
filepath=file_path + file_name + ".keras", | ||
monitor="loss", | ||
save_best_only=True, | ||
verbose=self.verbose, | ||
) | ||
if isinstance(callbacks, list): | ||
return callbacks + [model_checkpoint_] | ||
else: | ||
return [callbacks] + [model_checkpoint_] | ||
|
||
def summary(self): | ||
"""Summary function to return the losses/metrics for model fit. | ||
|
||
Returns | ||
------- | ||
history : dict or None | ||
Dictionary containing model's train/validation losses and metrics. | ||
""" | ||
return self.history_.history if self.history_ is not None else None | ||
|
||
def save_last_model_to_file(self, file_path="./"): | ||
"""Save the last epoch of the trained deep learning model. | ||
|
||
Parameters | ||
---------- | ||
file_path : str, default="./" | ||
The directory where the model will be saved. | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
import os | ||
|
||
if self.model_ is None: | ||
raise ValueError("No model to save. Please fit the model first.") | ||
self.model_.save(os.path.join(file_path, self.last_file_name + ".keras")) | ||
|
||
def load_model(self, model_path): | ||
"""Load a pre-trained keras model instead of fitting. | ||
|
||
When calling this function, all functionalities can be used | ||
such as predict with the loaded model. | ||
|
||
Parameters | ||
---------- | ||
model_path : str | ||
Path to the saved model file including extension. | ||
Example: model_path="path/to/file/best_model.keras" | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
import tensorflow as tf | ||
|
||
self.model_ = tf.keras.models.load_model(model_path) | ||
self.is_fitted = True | ||
|
||
@abstractmethod | ||
def build_model(self, input_shape): | ||
"""Build the deep learning model. | ||
|
||
Parameters | ||
---------- | ||
input_shape : tuple | ||
Shape of input data. | ||
|
||
Returns | ||
------- | ||
model : tf.keras.Model | ||
Compiled Keras model. | ||
""" | ||
pass | ||
lucifer4073 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def _get_test_params( | ||
cls, parameter_set: str = "default" | ||
) -> dict[str, Any] | list[dict[str, Any]]: | ||
""" | ||
Return testing parameter settings for the estimator. | ||
|
||
Parameters | ||
---------- | ||
parameter_set : str, default="default" | ||
Name of the set of test parameters to return, for use in tests. | ||
|
||
Returns | ||
------- | ||
params : dict or list of dict, default={} | ||
Parameters to create testing instances of the class. | ||
""" | ||
param = { | ||
"window": 10, | ||
} | ||
return [param] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Deep Learning Forecasting Tests File.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Test file for BaseDeepForecaster.""" | ||
|
||
import pytest | ||
|
||
from aeon.forecasting.deep_learning.base import BaseDeepForecaster | ||
from aeon.utils.validation._dependencies import _check_soft_dependencies | ||
|
||
|
||
class DummyDeepForecaster(BaseDeepForecaster): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doesnt implement abstract methods |
||
"""Minimal concrete subclass to allow instantiation.""" | ||
|
||
def __init__(self, window): | ||
super().__init__(window=window) | ||
|
||
def build_model(self, input_shape): | ||
"""Construct and return a model based on the provided input shape.""" | ||
return None # Not needed for this test | ||
|
||
|
||
@pytest.mark.skipif( | ||
not _check_soft_dependencies(["tensorflow"], severity="none"), | ||
reason="Tensorflow soft dependency unavailable.", | ||
) | ||
def test_default_init_attributes(): | ||
"""Test that BaseDeepForecaster sets default params and attributes correctly.""" | ||
forecaster = DummyDeepForecaster(window=10) | ||
|
||
# check default parameters | ||
assert forecaster.horizon == 1 | ||
assert forecaster.window == 10 | ||
assert forecaster.verbose == 0 | ||
assert forecaster.callbacks is None | ||
assert forecaster.axis == 0 | ||
assert forecaster.last_file_name == "last_model" | ||
assert forecaster.file_path == "./" | ||
|
||
# check default attributes after init | ||
assert forecaster.model_ is None | ||
assert forecaster.history_ is None | ||
assert forecaster.last_window_ is None | ||
|
||
# check tags | ||
tags = forecaster.get_tags() | ||
assert tags["algorithm_type"] == "deeplearning" | ||
assert tags["capability:horizon"] | ||
assert tags["capability:univariate"] |
Uh oh!
There was an error while loading. Please reload this page.