Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
### For developers of the library:

- Fixed all warnings when generating the documentation. Also, the documentation can now be generated without having to re-install Darts before every run. [#2936](https://github.com/unit8co/darts/pull/2936) by [Dennis Bader](https://github.com/dennisbader).
- Reworked the `_validate_model_params` function of `TorchForecastingModel` to support more complicated cases of class inheritance. [#2908](https://github.com/unit8co/darts/pull/2908) by [Tim Rosenflanz](https://github.com/trosenflanz).


## [0.38.0](https://github.com/unit8co/darts/tree/0.38.0) (2025-10-03)

Expand Down
20 changes: 14 additions & 6 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,22 @@ def encode_year(idx):

@classmethod
def _validate_model_params(cls, **kwargs):
"""validate that parameters used at model creation are part of :class:`TorchForecastingModel`,
:class:`PLForecastingModule` or cls __init__ methods.
"""validate that parameters used at model creation are part of the model cls __init__,
its parents __init__ methods, or :class:`PLForecastingModule`
"""
valid_kwargs = (
set(inspect.signature(TorchForecastingModel.__init__).parameters.keys())
| set(inspect.signature(PLForecastingModule.__init__).parameters.keys())
| set(inspect.signature(cls.__init__).parameters.keys())
# initiate with PLForecastingModule params that isn't part of the base class
valid_kwargs = set(
inspect.signature(PLForecastingModule.__init__).parameters.keys()
)
# add params from the full list of base classes
for base in inspect.getmro(cls):
if base is object:
break
sig = inspect.signature(base.__init__)
valid_kwargs.update(sig.parameters.keys())
# Remove 'self','args,'kwargs' from consideration
for generic_arg in ["self", "args", "kwargs"]:
valid_kwargs.discard(generic_arg)

invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs]

Expand Down
20 changes: 20 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,26 @@ def test_wrong_model_creation_params(self):
with pytest.raises(ValueError):
_ = RNNModel(12, "RNN", 10, 10, **invalid_kwarg)

def test_inherited_wrong_model_creation_params(self):
# test using inheritance class
class RnnModelLambda(RNNModel):
def __init__(self, positional_param, named_param=0, *args, **kwargs):
super().__init__(*args, **kwargs)

valid_kwargs = {
"pl_trainer_kwargs": {},
"named_param": 1,
"positional_param": 1,
}
invalid_kwargs = {"some_invalid_kwarg": None}

# valid params should not raise an error
_ = RnnModelLambda(0, input_chunk_length=12, **valid_kwargs)

# invalid params should raise an error
with pytest.raises(ValueError):
_ = RnnModelLambda(0, input_chunk_length=12, **invalid_kwargs)

def test_metrics(self):
metric = MeanAbsolutePercentageError()
metric_collection = MetricCollection([
Expand Down