Skip to content

[ENH] Added BaseDeepForecaster in forecasting/deep_learning #2905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
af94a86
Basedeep forecaster added
lucifer4073 May 22, 2025
7bb161e
Merge upstream main to basedlf
lucifer4073 May 24, 2025
d2ee9ec
init for basedlf added
lucifer4073 May 26, 2025
ab3030c
test file and axis added for basedeepforecaster
lucifer4073 Jun 15, 2025
1f202db
test locally
lucifer4073 Jun 15, 2025
14eb41f
dlf corrected
lucifer4073 Jun 15, 2025
d1a2aab
tf soft dep added
lucifer4073 Jun 22, 2025
865ed14
Merge remote-tracking branch 'upstream/main' into basedlf
lucifer4073 Jun 22, 2025
30d862a
base fst changed
lucifer4073 Jul 8, 2025
78b2f3d
test file corrected
lucifer4073 Jul 8, 2025
d1a7fd0
Merge branch 'main' into basedlf
lucifer4073 Jul 8, 2025
b6ccd07
basedelf updated
lucifer4073 Jul 20, 2025
a39fafb
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into ba…
lucifer4073 Jul 20, 2025
f7fd5bd
Merge branch 'main' into basedlf
lucifer4073 Jul 20, 2025
405fa80
test base chanegd
lucifer4073 Jul 20, 2025
004afed
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into basedlf
lucifer4073 Aug 11, 2025
c0ca211
Merge remote-tracking branch 'upstream/main' into basedlf
lucifer4073 Aug 17, 2025
bc1adba
current basedlf
lucifer4073 Aug 17, 2025
b25059d
base changed
lucifer4073 Aug 19, 2025
fdf5b6d
Merge branch 'main' into basedlf
lucifer4073 Aug 19, 2025
51f655a
Merge branch 'main' into basedlf
lucifer4073 Aug 19, 2025
f3af433
save best model changed
lucifer4073 Aug 19, 2025
003badc
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into ba…
lucifer4073 Aug 19, 2025
de01961
Merge remote-tracking branch 'upstream/main' into basedlf
lucifer4073 Aug 19, 2025
7ce3e68
conversations resolved
lucifer4073 Aug 19, 2025
50dbec6
conversations resolved
lucifer4073 Aug 19, 2025
a19c8ab
Merge branch 'main' into basedlf
TonyBagnall Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aeon/forecasting/deep_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Initialization for aeon forecasting deep learning module."""

__all__ = [
"BaseDeepForecaster",
]

from aeon.forecasting.deep_learning.base import BaseDeepForecaster
229 changes: 229 additions & 0 deletions aeon/forecasting/deep_learning/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Base class module for deep learning forecasters in aeon.

This module defines the `BaseDeepForecaster` class, an abstract base class for
deep learning-based forecasting models within the aeon toolkit.
"""

from __future__ import annotations

__maintainer__ = []
__all__ = ["BaseDeepForecaster"]

from abc import abstractmethod
from typing import Any

from aeon.forecasting.base import BaseForecaster


class BaseDeepForecaster(BaseForecaster):
"""Base class for deep learning forecasters in aeon.

This class provides a foundation for deep learning-based forecasting models,
handling data preprocessing, model training, and prediction with enhanced
capabilities for callbacks, model saving/loading, and efficiency.

Parameters
----------
window : int,
The window size for creating input sequences.
horizon : int, default=1
Forecasting horizon, the number of steps ahead to predict.
verbose : int, default=0
Verbosity mode (0, 1, or 2).
callbacks : list of tf.keras.callbacks.Callback or None, default=None
List of Keras callbacks to be applied during training.
axis : int, default=0
Axis along which to apply the forecaster.
last_file_name : str, default="last_model"
The name of the file of the last model, used for saving models.
file_path : str, default="./"
Directory path where models will be saved.

Attributes
----------
model_ : tf.keras.Model or None
The fitted Keras model.
history_ : tf.keras.callbacks.History or None
Training history containing loss and metrics.
last_window_ : np.ndarray or None
The last window of data used for prediction.
"""

_tags = {
"capability:horizon": True,
"capability:exogenous": False,
"algorithm_type": "deeplearning",
"non_deterministic": True,
"cant_pickle": True,
"python_dependencies": "tensorflow",
"capability:multivariate": False,
}

def __init__(
    self,
    window,
    horizon=1,
    verbose=0,
    callbacks=None,
    axis=0,
    last_file_name="last_model",
    file_path="./",
):
    """Store the constructor arguments and reset the fitted-state attributes."""
    # Constructor arguments are kept unchanged under their own names.
    self.window = window
    self.horizon = horizon
    self.axis = axis
    self.verbose = verbose
    self.callbacks = callbacks
    self.file_path = file_path
    self.last_file_name = last_file_name

    # Placeholders for fitted state; nothing is built or trained here.
    self.model_ = self.history_ = self.last_window_ = None

    super().__init__(horizon=horizon, axis=axis)

def _fit(self, y, exog=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default fit in the base class just returns self. Is the intention to require it for BaseDeepForecasters?

"""Fit the model."""
...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Abstract methods should be annotated as abstract (with `@abstractmethod`).

def _predict(self, y, exog=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same as the base class predict, so I don't think it is needed.

"""Predict using the model."""
...

def _prepare_callbacks(self):
"""Prepare callbacks for training.

Returns
-------
callbacks_list : list
List of callbacks to be used during training.
"""
callbacks_list = []
if self.callbacks is not None:
if isinstance(self.callbacks, list):
callbacks_list.extend(self.callbacks)
else:
callbacks_list.append(self.callbacks)

callbacks_list = self._get_model_checkpoint_callback(
callbacks_list, self.file_path, "best_model"
)
return callbacks_list

def _get_model_checkpoint_callback(self, callbacks, file_path, file_name):
    """Append a ``ModelCheckpoint`` callback that keeps the best model.

    The checkpoint monitors ``"loss"`` with ``save_best_only=True`` and
    writes to ``<file_path>/<file_name>.keras``.

    Parameters
    ----------
    callbacks : list, single callback, or None
        Existing callbacks. A list is extended in a new list (the input is
        not modified), a single callback is wrapped in a list, and ``None``
        is treated as no callbacks.
    file_path : str
        Directory path where the model will be saved.
    file_name : str
        Name of the model file, without the ``.keras`` extension.

    Returns
    -------
    callbacks : list
        New list containing the existing callbacks followed by the
        ``ModelCheckpoint`` callback.
    """
    import os

    import tensorflow as tf

    # Join with os.path.join so a directory given without a trailing
    # separator (e.g. "models") still yields "models/<file_name>.keras",
    # matching how save_last_model_to_file builds its path.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(file_path, file_name + ".keras"),
        monitor="loss",
        save_best_only=True,
        verbose=self.verbose,
    )

    if callbacks is None:
        # Avoid handing Keras a list that contains None.
        existing = []
    elif isinstance(callbacks, list):
        existing = callbacks
    else:
        existing = [callbacks]
    return existing + [checkpoint]

def summary(self):
    """Return the recorded training history, if any.

    Returns
    -------
    history : dict or None
        The ``history`` attribute of ``self.history_``, or ``None`` when
        ``self.history_`` is ``None``.
    """
    record = self.history_
    if record is None:
        return None
    return record.history

def save_last_model_to_file(self, file_path="./"):
    """Write the current model to ``<file_path>/<last_file_name>.keras``.

    Parameters
    ----------
    file_path : str, default="./"
        The directory where the model will be saved.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``self.model_`` is ``None``.
    """
    import os

    fitted_model = self.model_
    if fitted_model is None:
        raise ValueError("No model to save. Please fit the model first.")

    target = os.path.join(file_path, self.last_file_name + ".keras")
    fitted_model.save(target)

def load_model(self, model_path):
    """Use a previously saved Keras model in place of fitting.

    The model read from ``model_path`` is stored in ``self.model_`` and the
    estimator is flagged as fitted by setting ``self.is_fitted`` to ``True``.

    Parameters
    ----------
    model_path : str
        Path to the saved model file including extension.
        Example: model_path="path/to/file/best_model.keras"

    Returns
    -------
    None
    """
    import tensorflow as tf

    restored = tf.keras.models.load_model(model_path)
    self.model_ = restored
    self.is_fitted = True

@abstractmethod
def build_model(self, input_shape):
    """Construct the network for a concrete forecaster.

    Subclasses must override this method; the base implementation has no
    body and returns ``None``.

    Parameters
    ----------
    input_shape : tuple
        Shape of input data.

    Returns
    -------
    model : tf.keras.Model
        Compiled Keras model.
    """
    ...

@classmethod
def _get_test_params(
cls, parameter_set: str = "default"
) -> dict[str, Any] | list[dict[str, Any]]:
"""
Return testing parameter settings for the estimator.

Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests.

Returns
-------
params : dict or list of dict, default={}
Parameters to create testing instances of the class.
"""
param = {
"window": 10,
}
return [param]
1 change: 1 addition & 0 deletions aeon/forecasting/deep_learning/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Deep Learning Forecasting Tests File."""
46 changes: 46 additions & 0 deletions aeon/forecasting/deep_learning/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Test file for BaseDeepForecaster."""

import pytest

from aeon.forecasting.deep_learning.base import BaseDeepForecaster
from aeon.utils.validation._dependencies import _check_soft_dependencies


class DummyDeepForecaster(BaseDeepForecaster):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't implement the abstract methods.

"""Minimal concrete subclass to allow instantiation."""

def __init__(self, window):
super().__init__(window=window)

def build_model(self, input_shape):
"""Construct and return a model based on the provided input shape."""
return None # Not needed for this test


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="Tensorflow soft dependency unavailable.",
)
def test_default_init_attributes():
    """Check constructor defaults, initial fitted-state attributes and tags."""
    forecaster = DummyDeepForecaster(window=10)

    # Constructor parameters: window is passed explicitly, the rest are defaults.
    expected_params = {
        "horizon": 1,
        "window": 10,
        "verbose": 0,
        "axis": 0,
        "last_file_name": "last_model",
        "file_path": "./",
    }
    for name, expected in expected_params.items():
        assert getattr(forecaster, name) == expected
    assert forecaster.callbacks is None

    # Fitted-state attributes are all None straight after construction.
    for name in ("model_", "history_", "last_window_"):
        assert getattr(forecaster, name) is None

    # Tags declared by the estimator.
    tags = forecaster.get_tags()
    assert tags["algorithm_type"] == "deeplearning"
    assert tags["capability:horizon"]
    assert tags["capability:univariate"]
Loading