Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
d381d5e
arima first
TonyBagnall May 24, 2025
3a0552b
move utils
TonyBagnall May 24, 2025
0ac5380
make functions private
TonyBagnall May 24, 2025
44b36a7
Modularise SARIMA model
May 28, 2025
6d18de9
Add ARIMA forecaster to forecasting package
May 28, 2025
b7e6424
Add example to ARIMA forecaster, this also tests the forecaster is pr…
May 28, 2025
e33fa4d
Basic ARIMA model
May 28, 2025
f613f7e
Convert ARIMA to numba version
May 28, 2025
a6b708c
Merge branch 'main' into arb/base_arima
alexbanwell1 May 28, 2025
24ab433
Add Auto ARIMA starting point
May 28, 2025
5060928
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
9eb00f6
Adjust parameters to allow modification in fit
May 28, 2025
9ef70fa
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
f0c0443
Non-seasonal AutoARIMA Forecaster
May 28, 2025
5f2d80f
Numbafy AutoARIMA code
May 28, 2025
d4ed4b1
Update example and return native python type
May 28, 2025
0ecca96
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
2893e1b
Fix examples for tests
May 28, 2025
c83052b
Modify AutoARIMA function to take the model function as a parameter
May 28, 2025
9801e8b
Fix Nelder-Mead Optimisation Algorithm Example
May 28, 2025
94e9080
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
2f928c7
Fix Nelder-Mead Optimisation Algorithm Example #2
May 28, 2025
5c0ae94
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
94cd5b3
Remove Nelder-Mead Example due to issues with numba caching functions
May 28, 2025
a9a75dd
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
0d0d63f
Fix return type issue
May 28, 2025
628da30
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
39a3ed2
Address PR Feedback
May 28, 2025
05a2785
Ignore small tolerances in floating point value in output of example
May 28, 2025
fd3c846
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
73966ab
Fix kpss_test example
May 28, 2025
d00c3fe
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
a0f090d
Fix kpss_test example #2
May 28, 2025
a398967
Merge branch 'arb/base_arima' into arb/auto_arima
May 28, 2025
6884703
Update documentation for ARIMAForecaster, change constant_term to be …
Jun 2, 2025
e445d83
Merge branch 'arb/base_arima' into arb/auto_arima
Jun 2, 2025
93b3df8
Convert constant term to bool, add type hints
Jun 2, 2025
02a9c49
Add type hints
Jun 2, 2025
44a8647
Merge branch 'main' into arb/base_arima
alexbanwell1 Jun 2, 2025
1844225
Merge branch 'main' into arb/auto_arima
alexbanwell1 Jun 2, 2025
9af3a56
Modify ARIMA to allow predicting multiple values by updating the stat…
Jun 8, 2025
1456d1f
Merge branch 'arb/base_arima' into arb/auto_arima
Jun 9, 2025
4c63af5
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 9, 2025
e898f2f
Fix bug using self.d rather than self.d_
Jun 9, 2025
11c4987
Merge branch 'arb/base_arima' of https://github.com/aeon-toolkit/aeon…
Jun 9, 2025
16af9e9
Merge branch 'arb/base_arima' into arb/auto_arima
Jun 9, 2025
c0daa74
Update AutoARIMA to allow predicting multiple values without refittin…
Jun 10, 2025
ef6c5ad
Merge branch 'main' into arb/auto_arima
Aug 12, 2025
dfca539
Reorganise files to match current directory structure
Aug 12, 2025
45cf6bb
Complete Reorganisation and rename AutoARIMAForecaster to AutoARIMA
Aug 12, 2025
5f4105f
Add auto_arima wrapping basic ARIMA model
Aug 14, 2025
2b32514
First pass AutoETS
Aug 14, 2025
eddafad
Merge branch 'main' into arb/auto_ets
alexbanwell1 Aug 14, 2025
8de64ac
Merge branch 'main' into arb/auto_ets
alexbanwell1 Sep 23, 2025
f819d9d
Merge branch 'main' into arb/auto_ets
alexbanwell1 Sep 23, 2025
24c9a82
Move auto to main ets file
alexbanwell1 Sep 23, 2025
91c56b6
Attempt to fix ZDE by not testing lags > data length
alexbanwell1 Sep 24, 2025
c1fb681
Fix #2 to prevent ZDE - use min function not max function!
alexbanwell1 Sep 24, 2025
05a4405
Only test lags up to len(data)-1
alexbanwell1 Sep 24, 2025
ed8174f
Fix doctests and hangovers
alexbanwell1 Sep 24, 2025
4211721
Correct forecast method
alexbanwell1 Sep 24, 2025
3c694f1
Make _forecast more efficient
alexbanwell1 Sep 24, 2025
c8a1eb8
Add AutoETS tests
alexbanwell1 Sep 24, 2025
6f7f833
Fix test issue caused by linting fixes
alexbanwell1 Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aeon/forecasting/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Stats based forecasters."""

__all__ = [
"AutoETS",
"ARIMA",
"AutoARIMA",
"AutoTAR",
Expand All @@ -11,7 +12,7 @@
]

from aeon.forecasting.stats._arima import ARIMA, AutoARIMA
from aeon.forecasting.stats._ets import ETS
from aeon.forecasting.stats._ets import ETS, AutoETS
from aeon.forecasting.stats._tar import TAR, AutoTAR
from aeon.forecasting.stats._theta import Theta
from aeon.forecasting.stats._tvp import TVP
150 changes: 148 additions & 2 deletions aeon/forecasting/stats/_ets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""ETS class.
"""ETS and AutoETS class.

An implementation of the exponential smoothing statistics forecasting algorithm.
Implements additive and multiplicative error models. We recommend using the AutoETS
version, but this is useful for demonstrations.
"""

__maintainer__ = []
__all__ = ["ETS"]
__all__ = ["ETS", "AutoETS"]


import numpy as np
Expand All @@ -20,6 +20,7 @@
_ets_predict_value,
)
from aeon.forecasting.utils._nelder_mead import nelder_mead
from aeon.forecasting.utils._seasonality import calc_seasonal_period

ADDITIVE = "additive"
MULTIPLICATIVE = "multiplicative"
Expand Down Expand Up @@ -271,6 +272,111 @@ def iterative_forecast(self, y, prediction_horizon):
return preds


class AutoETS(BaseForecaster):
"""Automatic Exponential Smoothing forecaster.

An implementation of the exponential smoothing statistics forecasting algorithm.
Chooses betweek additive and multiplicative error models,
None, additive and multiplicative (including damped) trend and
None, additive and multiplicative seasonality[1]_.

Parameters
----------
horizon : int, default = 1
The horizon to forecast to.

References
----------
.. [1] R. J. Hyndman and G. Athanasopoulos,
Forecasting: Principles and Practice. Melbourne, Australia: OTexts, 2014.

Examples
--------
>>> from aeon.forecasting.stats import AutoETS
>>> from aeon.datasets import load_airline
>>> y = load_airline()
>>> forecaster = AutoETS()
>>> forecaster.forecast(y)
435.9312382780535
"""

_tags = {
"capability:horizon": False,
}

def __init__(self):
self.error_type_ = 0
self.trend_type_ = 0
self.seasonality_type_ = 0
self.seasonal_period_ = 0
self.wrapped_model_ = None
super().__init__(horizon=1, axis=1)

def _fit(self, y, exog=None):
"""Fit Auto Exponential Smoothing forecaster to series y.

Fit a forecaster to predict self.horizon steps ahead using y.

Parameters
----------
y : np.ndarray
A time series on which to learn a forecaster to predict horizon ahead
exog : np.ndarray, default =None
Optional exogenous time series data assumed to be aligned with y

Returns
-------
self
Fitted AutoETS.
"""
data = y.squeeze()
(
self.error_type_,
self.trend_type_,
self.seasonality_type_,
self.seasonal_period_,
) = auto_ets(data)
self.wrapped_model_ = ETS(
self.error_type_,
self.trend_type_,
self.seasonality_type_,
self.seasonal_period_,
)
self.wrapped_model_.fit(y, exog)
return self

def _predict(self, y=None, exog=None):
"""
Predict the next horizon steps ahead.

Parameters
----------
y : np.ndarray, default = None
A time series to predict the next horizon value for. If None,
predict the next horizon value after series seen in fit.
exog : np.ndarray, default =None
Optional exogenous time series data assumed to be aligned with y

Returns
-------
float
single prediction self.horizon steps ahead of y.
"""
return self.wrapped_model_.predict(y, exog)

def _forecast(self, y, exog=None, axis=1):
self.fit(y, exog=exog)
return float(self.wrapped_model_.forecast_)

def iterative_forecast(self, y, prediction_horizon):
"""Forecast with ETS specific iterative method.

Overrides the base class iterative_forecast to avoid refitting on each step.
This simply rolls the ETS model forward
"""
return self.wrapped_model_.iterative_forecast(y, prediction_horizon)


@njit(fastmath=True, cache=True)
def _numba_predict(
trend_type,
Expand Down Expand Up @@ -320,3 +426,43 @@ def _validate_parameter(var, can_be_none):
f"variable must be either string or integer with values"
f" {valid_str} or {valid_int} but saw {var}"
)


def auto_ets(data):
"""Calculate model parameters based on the internal nelder-mead implementation."""
seasonal_period = calc_seasonal_period(data)
lowest_aic = -1
best_model = None
for error_type in range(1, 3):
for trend_type in range(0, 3):
for seasonality_type in range(0, 2 * (seasonal_period != 1) + 1):
model_seasonal_period = seasonal_period
if seasonal_period < 1 or seasonality_type == 0:
model_seasonal_period = 1
model = np.array(
[
error_type,
trend_type,
seasonality_type,
model_seasonal_period,
],
dtype=np.int32,
)
try:
(_, aic) = nelder_mead(
1,
1 + 2 * (trend_type != 0) + (seasonality_type != 0),
data,
model,
)
except ZeroDivisionError:
continue
if lowest_aic == -1 or lowest_aic > aic:
lowest_aic = aic
best_model = (
error_type,
trend_type,
seasonality_type,
model_seasonal_period,
)
return best_model
114 changes: 113 additions & 1 deletion aeon/forecasting/stats/tests/test_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from aeon.forecasting.stats._ets import ETS, _validate_parameter
from aeon.forecasting.stats._ets import ETS, AutoETS, _validate_parameter


@pytest.mark.parametrize(
Expand Down Expand Up @@ -105,3 +105,115 @@ def test_ets_iterative_forecast():
forecaster = ETS(trend_type=None)
forecaster._fit(y)
assert forecaster._trend_type == 0


# small seasonal-ish series (same as in ETS tests)
Y_SEASONAL = np.array(
[3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12], dtype=float
)
# another shortish series for basic sanity checks
Y_SHORT = np.array([10, 12, 14, 13, 15, 16, 18, 19, 20, 21, 22, 23], dtype=float)


def test_autoets_fit_sets_attributes_and_wraps():
"""Fit should set type/period attributes and wrap an ETS instance."""
forecaster = AutoETS()
forecaster.fit(Y_SEASONAL)

# wrapped model exists and is ETS
assert forecaster.wrapped_model_ is not None
assert isinstance(forecaster.wrapped_model_, ETS)

# discovered structure attributes should exist and be integers >= 0
for attr in ("error_type_", "trend_type_", "seasonality_type_", "seasonal_period_"):
val = getattr(forecaster, attr)
assert isinstance(val, (int, np.integer))
assert val >= 0

# wrapped model should have been fitted and expose a finite forecast_
assert hasattr(forecaster.wrapped_model_, "forecast_")
assert np.isfinite(forecaster.wrapped_model_.forecast_)


def test_autoets_predict_returns_finite_float():
"""_predict should return a finite float once fitted."""
forecaster = AutoETS()
forecaster.fit(Y_SHORT)
pred = forecaster._predict(Y_SHORT)
assert isinstance(pred, float)
assert np.isfinite(pred)


def test_autoets_forecast_sets_wrapped_and_returns_forecast_float():
"""_forecast should fit internally, set wrapped forecast_, and return that value."""
forecaster = AutoETS()
f = forecaster._forecast(Y_SEASONAL)
assert isinstance(f, float)
assert np.isfinite(f)
assert forecaster.wrapped_model_ is not None
assert hasattr(forecaster.wrapped_model_, "forecast_")
assert np.isclose(f, float(forecaster.wrapped_model_.forecast_))


def test_autoets_iterative_forecast_shape_and_validity():
"""iterative_forecast should delegate to wrapped ETS and return valid outputs."""
h = 5
forecaster = AutoETS()
forecaster.fit(Y_SHORT)
preds = forecaster.iterative_forecast(Y_SHORT, prediction_horizon=h)

assert isinstance(preds, np.ndarray)
assert preds.shape == (h,)
assert np.all(np.isfinite(preds))

# Optional: first iterative step should match one-step-ahead forecast after fit
assert np.isclose(preds[0], forecaster.wrapped_model_.forecast_, atol=1e-6)


def test_autoets_horizon_greater_than_one_raises():
"""
AutoETS.fit should raise ValueError.

when horizon > 1 (ETS only supports 1-step fit).
"""
forecaster = AutoETS()
forecaster.horizon = 2
with pytest.raises(ValueError, match="Horizon is set >1"):
forecaster.fit(Y_SEASONAL)


def test_autoets_predict_matches_wrapped_predict():
"""_predict should match the wrapped ETS model's predict."""
forecaster = AutoETS()
forecaster.fit(Y_SEASONAL)
a = forecaster._predict(Y_SEASONAL)
b = forecaster.wrapped_model_.predict(Y_SEASONAL)
assert isinstance(a, float) and isinstance(b, float)
assert np.isfinite(a) and np.isfinite(b)
assert np.isclose(a, b)


def test_autoets_forecast_is_consistent_with_wrapped():
"""_forecast should equal the wrapped model's forecast after internal fit."""
forecaster = AutoETS()
val = forecaster._forecast(Y_SHORT)
assert np.isclose(val, float(forecaster.wrapped_model_.forecast_))


def test_autoets_exog_raises():
"""AutoETS.fit should raise ValueError when exog passed."""
forecaster = AutoETS()
exog = np.arange(len(Y_SEASONAL), dtype=float) # simple aligned exogenous regressor
with pytest.raises(
ValueError,
match="AutoETS cannot handle exogenous variables",
):
forecaster.fit(Y_SEASONAL, exog=exog)


def test_autoets_repeatability_on_same_input():
"""Forecasting twice on the same series should be deterministic."""
forecaster = AutoETS()
f1 = forecaster._forecast(Y_SEASONAL)
f2 = forecaster._forecast(Y_SEASONAL)
assert np.isclose(f1, f2)
2 changes: 1 addition & 1 deletion aeon/forecasting/utils/_seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def calc_seasonal_period(data):
The estimated seasonal period (lag) of the series. Returns 1 if no significant
peak is detected in the autocorrelation.
"""
lags = acf(data, 24)
lags = acf(data, min(24, len(data) - 1))
lags = np.concatenate((np.array([1.0]), lags))
peaks = []
mean_lags = np.mean(lags)
Expand Down