Skip to content

[ENH] fit_is_empty ETS #2895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 112 additions & 161 deletions aeon/forecasting/stats/_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@
from aeon.forecasting.utils._extract_paras import _extract_ets_params
from aeon.forecasting.utils._loss_functions import (
_ets_fit,
_ets_initialise,
_ets_predict_value,
)
from aeon.forecasting.utils._nelder_mead import nelder_mead

ADDITIVE = "additive"
MULTIPLICATIVE = "multiplicative"


class ETS(BaseForecaster, IterativeForecastingMixin):
"""Exponential Smoothing (ETS) forecaster.
Expand All @@ -44,37 +40,8 @@ class ETS(BaseForecaster, IterativeForecastingMixin):
Type of seasonal component: None (0), `additive' (1) or 'multiplicative' (2)
seasonal_period : int, default=1
Number of time points in a seasonal cycle.
alpha : float, default=0.1
Level smoothing parameter.
beta : float, default=0.01
Trend smoothing parameter.
gamma : float, default=0.01
Seasonal smoothing parameter.
phi : float, default=0.99
Trend damping parameter (used only for damped trend models).

Attributes
----------
forecast_val_ : float
Forecast value for the given horizon.
level_ : float
Estimated level component.
trend_ : float
Estimated trend component.
seasonality_ : array-like or None
Estimated seasonal components.
aic_ : float
Akaike Information Criterion of the fitted model.
avg_mean_sq_err_ : float
Average mean squared error of the fitted model.
residuals_ : list of float
Residuals from the fitted model.
fitted_values_ : list of float
Fitted values for the training data.
liklihood_ : float
Log-likelihood of the fitted model.
n_timepoints_ : int
Number of time points in the training series.
iterations : int, default=200
Number of iterations for the Nelder-Mead optimisation algorithm used to fit.

References
----------
Expand All @@ -97,6 +64,7 @@ class ETS(BaseForecaster, IterativeForecastingMixin):

_tags = {
"capability:horizon": False,
"fit_is_empty": True,
}

def __init__(
Expand All @@ -107,48 +75,88 @@ def __init__(
seasonal_period: int = 1,
iterations: int = 200,
):
self.forecast_val_ = 0.0
self.level_ = 0.0
self.trend_ = 0.0
self.seasonality_ = None
self.error_type = error_type
self.trend_type = trend_type
self.seasonality_type = seasonality_type
self.seasonal_period = seasonal_period
self.iterations = iterations
self.n_timepoints_ = 0
self.avg_mean_sq_err_ = 0
self.liklihood_ = 0
self.k_ = 0
self.aic_ = 0
self.residuals_ = []
self.fitted_values_ = []
self._model = []
self.parameters_ = []
self.alpha_ = 0
self.beta_ = 0
self.gamma_ = 0
self.phi_ = 0
self.forecast_ = 0
super().__init__(horizon=1, axis=1)

def _fit(self, y, exog=None):
"""Fit Exponential Smoothing forecaster to series y.
super().__init__(horizon=1, axis=1)

Fit a forecaster to predict self.horizon steps ahead using y.
def _predict(self, y=None, exog=None):
"""
Predict the next horizon steps ahead.

Parameters
----------
y : np.ndarray
A time series on which to learn a forecaster to predict horizon ahead
y : np.ndarray, default = None
A time series to predict the next horizon value for. If None,
predict the next horizon value after series seen in fit.
exog : np.ndarray, default =None
Optional exogenous time series data assumed to be aligned with y

Returns
-------
self
Fitted ETS.
float
single prediction self.horizon steps ahead of y.
"""
(
trend_type,
seasonality_type,
seasonal_period,
level,
trend,
seasonality,
n_timepoints,
phi,
) = self._shared_fit(y)

fitted_value = _numba_predict(
trend_type,
seasonality_type,
level,
trend,
seasonality,
phi,
self.horizon,
n_timepoints,
seasonal_period,
)
return fitted_value

def iterative_forecast(self, y, prediction_horizon):
"""Forecast with ETS specific iterative method.

Overrides the base class iterative_forecast to avoid refitting on each step.
This simply rolls the ETS model forward
"""
(
trend_type,
seasonality_type,
seasonal_period,
level,
trend,
seasonality,
n_timepoints,
phi,
) = self._shared_fit(y)

preds = np.zeros(prediction_horizon)
for i in range(0, prediction_horizon):
preds[i] = _numba_predict(
trend_type,
seasonality_type,
level,
trend,
seasonality,
phi,
i + 1,
n_timepoints,
seasonal_period,
)
return preds

def _shared_fit(self, y):
_validate_parameter(self.error_type, False)
_validate_parameter(self.seasonality_type, True)
_validate_parameter(self.trend_type, True)
Expand All @@ -157,120 +165,63 @@ def _fit(self, y, exog=None):
def _get_int(x):
if x is None:
return 0
if x == ADDITIVE:
if x == "additive":
return 1
if x == MULTIPLICATIVE:
if x == "multiplicative":
return 2
return x

self._error_type = _get_int(self.error_type)
self._seasonality_type = _get_int(self.seasonality_type)
self._trend_type = _get_int(self.trend_type)
self._seasonal_period = self.seasonal_period
if self._seasonal_period < 1 or self._seasonality_type == 0:
self._seasonal_period = 1
self._model = np.array(
error_type = _get_int(self.error_type)
seasonality_type = _get_int(self.seasonality_type)
trend_type = _get_int(self.trend_type)
seasonal_period = self.seasonal_period
if seasonal_period < 1 or seasonality_type == 0:
seasonal_period = 1

model = np.array(
[
self._error_type,
self._trend_type,
self._seasonality_type,
self._seasonal_period,
error_type,
trend_type,
seasonality_type,
seasonal_period,
],
dtype=np.int32,
)
data = y.squeeze()
(self.parameters_, self.aic_) = nelder_mead(

parameters, aic = nelder_mead(
1,
1 + 2 * (self._trend_type != 0) + (self._seasonality_type != 0),
1 + 2 * (trend_type != 0) + (seasonality_type != 0),
data,
self._model,
model,
max_iter=self.iterations,
)
self.alpha_, self.beta_, self.gamma_, self.phi_ = _extract_ets_params(
self.parameters_, self._model
)
(
self.aic_,
self.level_,
self.trend_,
self.seasonality_,
self.n_timepoints_,
self.residuals_,
self.fitted_values_,
self.avg_mean_sq_err_,
self.liklihood_,
self.k_,
) = _ets_fit(self.parameters_, data, self._model)
self.forecast_ = _numba_predict(
self._trend_type,
self._seasonality_type,
self.level_,
self.trend_,
self.seasonality_,
self.phi_,
self.horizon,
self.n_timepoints_,
self._seasonal_period,
)

return self
alpha, beta, gamma, phi = _extract_ets_params(parameters, model)

def _predict(self, y, exog=None):
"""
Predict the next horizon steps ahead.

Parameters
----------
y : np.ndarray, default = None
A time series to predict the next horizon value for. If None,
predict the next horizon value after series seen in fit.
exog : np.ndarray, default =None
Optional exogenous time series data assumed to be aligned with y

Returns
-------
float
single prediction self.horizon steps ahead of y.
"""
return self.forecast_

def _initialise(self, data):
"""
Initialize level, trend, and seasonality values for the ETS model.

Parameters
----------
data : array-like
The time series data
(should contain at least two full seasons if seasonality is specified)
"""
self.level_, self.trend_, self.seasonality_ = _ets_initialise(
self._trend_type, self._seasonality_type, self._seasonal_period, data
(
aic,
level,
trend,
seasonality,
n_timepoints,
residuals,
fitted_values,
avg_mean_sq_err,
likelihood,
k,
) = _ets_fit(parameters, data, model)

return (
trend_type,
seasonality_type,
seasonal_period,
level,
trend,
seasonality,
n_timepoints,
phi,
)

def iterative_forecast(self, y, prediction_horizon):
"""Forecast with ETS specific iterative method.

Overrides the base class iterative_forecast to avoid refitting on each step.
This simply rolls the ETS model forward
"""
self.fit(y)
preds = np.zeros(prediction_horizon)
preds[0] = self.forecast_
for i in range(1, prediction_horizon):
preds[i] = _numba_predict(
self._trend_type,
self._seasonality_type,
self.level_,
self.trend_,
self.seasonality_,
self.phi_,
i + 1,
self.n_timepoints_,
self._seasonal_period,
)
return preds


@njit(fastmath=True, cache=True)
def _numba_predict(
Expand Down Expand Up @@ -302,10 +253,10 @@ def _numba_predict(


def _validate_parameter(var, can_be_none):
valid_str = (ADDITIVE, MULTIPLICATIVE)
valid_str = ("additive", "multiplicative")
valid_int = (1, 2)
if can_be_none:
valid_str = (None, ADDITIVE, MULTIPLICATIVE)
valid_str = (None, "additive", "multiplicative")
valid_int = (0, 1, 2)
valid = True
if isinstance(var, str) or var is None:
Expand Down
11 changes: 3 additions & 8 deletions aeon/forecasting/stats/tests/test_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_ets_raises_on_horizon_greater_than_one():
forecaster.horizon = 2
data = np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12])
with pytest.raises(ValueError, match="Horizon is set >1, but"):
forecaster.fit(data)
forecaster.predict(data)


def test_ets_iterative_forecast():
Expand All @@ -98,10 +98,5 @@ def test_ets_iterative_forecast():
assert np.all(np.isfinite(preds)), "All forecast values should be finite"

# Optional: check that the first prediction equals forecast_ from .fit()
forecaster.fit(y)
assert np.isclose(
preds[0], forecaster.forecast_, atol=1e-6
), "First forecast should match forecast_"
forecaster = ETS(trend_type=None)
forecaster._fit(y)
assert forecaster._trend_type == 0
p = forecaster.predict(y)
assert np.isclose(preds[0], p, atol=1e-6), "First forecast should match predict"
1 change: 0 additions & 1 deletion aeon/forecasting/utils/_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def _ets_update_states(
@njit(fastmath=True, cache=True)
def _ets_predict_value(trend_type, seasonality_type, level, trend, seasonality, phi):
"""

Generate various useful values, including the next fitted value.

Parameters
Expand Down
Loading