diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 81313151d..1419b8360 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -1,813 +1,1775 @@ { - "cells": [ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rzVbXsEBxnF-" + }, + "source": [ + "# A Basic Vignette for `pytorch-forecasting v2` Model Training and Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yt0uZV7Px-40" + }, + "source": [ + "
\n", + ":warning: The \"Data Pipeline\" showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged β€” please share them in issue 1736.\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r15UunnLoxnK" + }, + "source": [ + "In this notebook, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n", + "We can do this in 2 ways:\n", + "1. **High-level package API:**\n", + "\n", + " This approach handles data loading, dataloader creation, and model training internally. It provides a simple, `scikit-learn`-like `fit` β†’ `predict` workflow.\n", + " Users can still configure key training options (such as the `trainer`, callbacks, and training parameters) but cannot plug in fully custom `trainer` implementations or override internal pipeline logic.\n", + "\n", + "2. **Low-level 3-stage pipeline**:\n", + "This involves explicitly constructing:\n", + " * a `TimeSeries` object\n", + "\n", + " * a `DataModule`\n", + "\n", + " * the model (e.g., `TFT`)\n", + " \n", + " This workflow is ideal if you need custom setups such as custom trainers, callbacks, or advanced data preprocessing.\n", + " It requires a deeper understanding of how the three layers (TimeSeries, DataModule, and the model) interact, but offers maximum flexibility." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QyMFNk4MyY_b" + }, + "source": [ + "# Create Synthetic data\n", + "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now **the pipeline assumes the data to be numerical only**." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "RkgOT4kiy_RU" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.examples import load_toydata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "2ad916b8-2fd9-4318-afb1-2bda84d284d7" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "rzVbXsEBxnF-" + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6712252870750063,\n \"min\": -1.2780952045426857,\n \"max\": 1.3163602917006327,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.19335967827533446,\n 0.8492207493147326,\n -0.9687640491099185\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6753351884449413,\n \"min\": -1.2780952045426857,\n \"max\": 1.3163602917006327,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6981263626070341,\n 0.7052787051636003,\n -0.861386757323439\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 
2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2792423704109133,\n \"min\": 0.031153133884698536,\n \"max\": 0.9662188410416612,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.24602577096925082,\n 0.8680231736929984,\n 0.6913124004679789\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" }, - "source": [ - "# `pytorch-forecasting v2` Model Training and Inference - Beta API" + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0306430.14828001.0000000.0392130
1010.1482800.43302900.9950040.0392130
2020.4330290.74251100.9800670.0392130
3030.7425110.72927000.9553360.0392130
4040.7292700.62860400.9210610.0392130
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 -0.030643 0.148280 0 1.000000 \n", + "1 0 1 0.148280 0.433029 0 0.995004 \n", + "2 0 2 0.433029 0.742511 0 0.980067 \n", + "3 0 3 0.742511 0.729270 0 0.955336 \n", + "4 0 4 0.729270 0.628604 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.039213 0 \n", + "1 0.039213 0 \n", + "2 0.039213 0 \n", + "3 0.039213 0 \n", + "4 0.039213 0 " ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_series = 100 # Number of individual time series to generate\n", + "seq_length = 50 # Length of each time series\n", + "data_df = load_toydata(num_series, seq_length)\n", + "data_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_8TgLH82runO" + }, + "source": [ + "# High-level API\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A1cqKCRur4oj" + }, + "source": [ + "## Steps\n", + "* Create the `TimeSeries` object\n", + "* Create `configs` for model, `datamodule`, `trainer` etc.\n", + "* Create the `model_pkg` object\n", + "* perform `pkg.fit` and `pkg.predict`.\n", + "\n", + "## Create Dataset object\n", + "\n", + "`TimeSeries` returns the raw data in terms of tensors .\n", + "\n", + "---\n", + "\n", + "`TimeSeries` dataset's Key arguments:\n", + "- `data`: DataFrame with sequence data.\n", + "- `time`: integer typed column denoting the time index within `data`.\n", + "- `target`: Column(s) in `data` denoting the forecasting target.\n", + "- `group`: List of column names identifying a time series instance within `data`.\n", + "- `num`: List of numerical features.\n", + "- `cat`: List of categorical features.\n", + "- `known`: Features known in future\n", + "- `unknown`: Features not known in the future\n", + "- `static`: List of variables that do not change over time,\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "u8OPR0HntXqR" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.timeseries import TimeSeries" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "6a_oy4VjtrHQ", + "outputId": "54678fb8-864e-4f32-eeb9-83697946a3e5" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "yt0uZV7Px-40" - }, - "source": [ - "
\n", - ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice.\n", - "\n", - "Feedback and suggestions are highly encouraged β€” please share them in issue 1736.\n", - "
\n" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/data/timeseries/_timeseries_v2.py:105: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n" + ] + } + ], + "source": [ + "# create `TimeSeries` dataset that returns the raw data in terms of tensors\n", + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EoS6W9zh6wCj" + }, + "source": [ + "## Create the configs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "MKPXPUcC5dTY" + }, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.metrics import MAE, SMAPE" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WYl9-oZz6nk6" + }, + "source": [ + "Here we use `EncoderDecoderTimeSeriesDataModule`\n", + "\n", + "\n", + "`EncoderDecoderTimeSeriesDataModule` key arguments:\n", + "- `time_series_dataset`: `TimeSeries` dataset instance\n", + "- `max_encoder_length` : Maximum length of the encoder input sequence.\n", + "- `max_prediction_length` : Maximum length of the decoder output sequence.\n", + "- `batch_size` : Batch size for DataLoader.\n", + "- `categorical_encoders` : Dictionary of categorical encoders.\n", + "- `scalers` : Dictionary of feature scalers.\n", + "- `target_normalizer`: Normalizer for the target variable." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "YGMShzfyttp_" + }, + "outputs": [], + "source": [ + "datamodule_cfg = dict(\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pi5Qkznh6t3y" + }, + "source": [ + "We would use `TFT` model in this tutorial" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "q6Thm13ct7OV" + }, + "outputs": [], + "source": [ + "model_cfg = dict(\n", + " loss=MAE(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "Stfuc_xCuON6" + }, + "outputs": [], + "source": [ + "trainer_cfg = dict(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "XS_ND8UAubdN" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import (\n", + " TFT_pkg_v2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6yoqI8907DG4" + }, + "source": [ + "## Create the `model_pkg` object\n", + "\n", + "This `pkg` class acts as a wrapper around the whole ML pipeline in `pytorch-forecasting` and we can simply just define the `pkg` class and then use `pkg.fit` and `pkg.predict` to perform the \"fit\", \"predict\" mechanisms." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "aOxng4Rguwj2", + "outputId": "2c50fcad-f990-4aae-f0bb-5dbdd6a87377" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "6D9ARyp05R0t" - }, - "source": [ - "In this vignette, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n", - "\n", - "\n", - "## Steps\n", - "\n", - "1. **Load Data** \n", - "2. **Create Dataset & DataModule** \n", - "3. **Initialize, Train & Run Inference with the Model**\n", - "\n", - "\n", - "\n", - "### Load Data\n", - "\n", - "We generate a synthetic dataset using `load_toydata` which returns a `pandas` DataFrame with purely numerical values. 
\n", - "*(Note: The current pipeline assumes all inputs are numerical only.)*\n", - "\n", - "\n", - "\n", - "\n", - "### Create Dataset & DataModule\n", - "\n", - "- `TimeSeries` returns the raw data in terms of tensors .\n", - "- `DataModule` wraps the dataset, handles splits, preprocessing, batching, and exposes `metadata` for the model initialisation.\n", - "\n", - "\n", - "\n", - "### Initialize the Model\n", - "\n", - "We initialize the TFT model using the `metadata` provided by the `DataModule`. This metadata includes all required dimensional info for the encoder, decoder, and static inputs.\n", - "\n", - "\n", - "\n", - "### Train the Model\n", - "\n", - "We use a `Trainer` from PyTorch Lightning to train the model\n", - "\n", - "### Run Inference\n", - "\n", - "After training, we can make predictions using the trained model\n" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "{'loss': MAE(), 'logging_metrics': [MAE(), SMAPE()], 'optimizer': 'adam', 'optimizer_params': {'lr': 0.001}, 'lr_scheduler': 'reduce_lr_on_plateau', 'lr_scheduler_params': {'mode': 'min', 'factor': 0.1, 'patience': 10}, 'hidden_size': 64, 'num_layers': 2, 'attention_head_size': 4, 'dropout': 0.1}\n" + ] + } + ], + "source": [ + "model_pkg = TFT_pkg_v2(\n", + " model_cfg=model_cfg,\n", + " trainer_cfg=trainer_cfg,\n", + " datamodule_cfg=datamodule_cfg,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 976, + "referenced_widgets": [ + "4ecdea6764d145118ab53e59451d2b0c", + "e7b969aa6d8e433d9aeeac4357bc425d", + "42707b895305490b82cd644250e689fa", + "19c3c106d5a9489cae445a0b5fc88183", + "a6e5908902eb40e997e6086287f28f2a", + "f23d99cc4f01426eb4fc8d41fc8f4b16", + "df8d7458b0fa4f508d4ce357fc95c609", + "464464a47d604d708be37e18edec4810", + "c7ca9662eae04999b21de91c183a0856", + "e82943533ad54539a777d6adae271d0f", + "e0fa236745204d5ba2dbc1ff2c51f1a2", + "c52bb8ff12db4df3a05cf1da7b5470f7", + "273fb7ccddeb476f9c76bd1be44a6ae0", + "922dfb34c7494b20a874e294c07447e0", + "96eda6cd2fbc47b9a29d2cf176332058", + "d2ed70b544924436b185f79d0cd90862", + "aff8baef21494ccf99bf092fc3daaae0", + "640d3876c2ce49de9757962b5b5b0e32", + "f99327891eeb424b9df4fe04a6bedcfb", + "04804ca4c425464db2754abf3cf95568", + "386a2097c02f4af6ba239e385f4b0b47", + "7e95bdbc3a6c46bf8b1e96bdecfa7303", + "572629c64cfd46a983f1a8c6483a2cf1", + "61f356a200c5470db777e7e0f9e8c520", + "4905c6d809274aa39a984e3e458fc89a", + "b16522c88ebb435f96c05315ef91ebbb", + "65cce98698644285896363e509ae6139", + "a4a4c07f117e46989cb4877a5d2dd9e0", + "3f8cc20607db40c7acb4110feba9ab0d", + "3c5c1f55d5a64838b94fc0fbd85097c6", + "4f5cce37b6ac430e85757dd06b06953a", + "942645c506ea436fb598455d84c8a970", + "969a3ddfaed84150944d697307ababe4", + "955c5e9c139148a1a352d17202fe097f", + "f0fbcbcf02e443bc99a469cf4c7f8131", + "72fb23e179594f35a68418e2e0ee65fc", + "7808bf48e45940cfa0d4bccae784d730", + "ed0358a45ec14ce687fc02904a815e38", + "67a3d79f1b2e4e03b9a564286c04d5d4", + "27fd0da590314bb68dfac5b7c72d6584", + "15e539660a2547f49fb2cf8a6143f5fa", + "c46b831b37a347868f1d35d0dbbfd923", + "3f985da9d6a245c5b54dbb47926a4fd4", + "5ab64c01efb84e75af1a8aaf6675f5bf", + "8d0747756fd2434399ae8d233a82d607", + "f77d800d097b494ca3e945abdaedd75c", + "05a14444ea4043dea69a4e7185e66cb1", + "b775518f409449928c3211260d7223c0", + "2db9a1e74ad14139af235f1a2a146e0a", + "2bcdfdf1b12c495aa8b425c88fcfbd1b", + "d71a01b309e948239a16062097ee76a2", + "006cdf49ce55411bb072c2670d87773f", + 
"5b3be082628244948284a40bea451ff1", + "bf96e25bd5d64892a329b624961abeb8", + "b5e879fa1fee4d0ba30ac5af07d1d8c5", + "78f2d725dcb34deca5407c277c384d8d", + "e1dc997d76d54eb9a9245530a60c2cd9", + "8fa00a6415b74091a012bc5fff543f42", + "c74c472cdf174a28a4ca3fed1b312332", + "409d65e2b79f49318c580d9835ffc29c", + "568141e0b45f44ff9b497c6474d8019f", + "1a889d0b55bb4e6d80d5f170297a6262", + "679072fe36f5404588879eab670e01e2", + "ab72364dc8cf433e907c40df3e7be9e9", + "d56c627297bd435fbbc60317066084f9", + "6a62c54d7d7b4f689dc31e57aaa20411", + "3ab523d60fd249a7bcece32280872abd", + "ef8e78e4f8a248dcbbc2ea3e464c5922", + "56685ebbfd244154bd1829dec6f0db0b", + "dc5f9a923d27492cb382691ec01a1ddc", + "58995b1bd1c24433a3aca0ab53c6b8bf", + "9082a1b6eb3a4d14b4c92eedec1c2404", + "13eb0c5265ad48d08b3f8e46a55896f0", + "f91e202108684aa9af76fdf3e9d83206", + "f617590dcd184fd99f98515940ac85af", + "33a1e5f21b694e3bb62d2c0d73aa65e3", + "8c8c8832e16c4d489e0df7514dc78f6a" + ] }, + "id": "c27Qj4QAvFwx", + "outputId": "21bbd594-d92e-498b-bd02-71829295c483" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "QyMFNk4MyY_b" - }, - "source": [ - "# 1. Load Data\n", - "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now **the pipeline assumes the data to be numerical only**." - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/data/data_module.py:129: UserWarning: EncoderDecoderTimeSeriesDataModule is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n", + "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:64: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. 
Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n", + "INFO: GPU available: True (cuda), used: True\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MAE | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MAE | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] }, { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "RkgOT4kiy_RU" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4ecdea6764d145118ab53e59451d2b0c", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/aryan/pytorch-forecasting/pytorch_forecasting/models/base/_base_model.py:28: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from tqdm.autonotebook import tqdm\n" - ] - } - ], - "source": [ - "from pytorch_forecasting.data.examples import load_toydata" + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.1677120.17215401.0000000.3005090
1010.1721540.46723300.9950040.3005090
2020.4672330.55495200.9800670.3005090
3030.5549520.74652900.9553360.3005090
4040.7465290.71174500.9210610.3005090
\n", - "" - ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 0.167712 0.172154 0 1.000000 \n", - "1 0 1 0.172154 0.467233 0 0.995004 \n", - "2 0 2 0.467233 0.554952 0 0.980067 \n", - "3 0 3 0.554952 0.746529 0 0.955336 \n", - "4 0 4 0.746529 0.711745 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.300509 0 \n", - "1 0.300509 0 \n", - "2 0.300509 0 \n", - "3 0.300509 0 \n", - "4 0.300509 0 " - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "num_series = 100 # Number of individual time series to generate\n", - "seq_length = 50 # Length of each time series\n", - "data_df = load_toydata(num_series, seq_length)\n", - "data_df.head()" + "text/plain": [ + "Training: | | 0/? [00:00 dict: + """ + Loads configuration from a dictionary, YAML file, or Pickle file. + """ + if config is None: + if ckpt_path and auto_file_name: + path = Path(ckpt_path).parent / auto_file_name + if path.exists(): + with open(path, "rb") as f: + return pickle.load(f) # noqa : S301 + return {} + + if isinstance(config, dict): + return config + + path = Path(config) + if not path.exists(): + raise FileNotFoundError(f"Configuration file not found: {path}") + + suffix = path.suffix.lower() + print(suffix) + + if suffix in [".yaml", ".yml"]: + with open(path) as f: + return yaml.safe_load(f) or {} + + else: + raise ValueError( + f"Unsupported config format: {suffix}. Use .yaml, .yml, or .pkl" + ) + + @classmethod + def get_cls(cls): + """Get the underlying model class.""" + raise NotImplementedError("Subclasses must implement `get_cls`.") + + @classmethod + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" + raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.") + + @classmethod + def get_test_dataset_from(cls, **kwargs): + """ + Creates and returns D1 TimeSeries dataSet objects for testing. 
+ """ + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + raw_data = data_with_covariates_v2() + + datasets_info = make_datasets_v2(raw_data, **kwargs) + + return { + "train": datasets_info["training_dataset"], + "predict": datasets_info["validation_dataset"], + } + + def _build_model(self, metadata: dict, **kwargs): + """Instantiates the model, either from a checkpoint or from config.""" + model_cls = self.get_cls() + if self.ckpt_path: + self.model = model_cls.load_from_checkpoint( + self.ckpt_path, metadata=metadata, **kwargs + ) + elif self.model_cfg: + self.model = model_cls(**self.model_cfg, metadata=metadata) + else: + self.model = None + + def _build_datamodule(self, data: TimeSeries) -> LightningDataModule: + """Constructs a DataModule from a D1 layer object.""" + if not self.datamodule_cfg: + raise ValueError("`datamodule_cfg` must be provided to build a datamodule.") + datamodule_cls = self.get_datamodule_cls() + return datamodule_cls(data, **self.datamodule_cfg) + + def _load_dataloader( + self, data: Union[TimeSeries, LightningDataModule, DataLoader] + ) -> DataLoader: + """Converts various data input types into a DataLoader for prediction.""" + if isinstance(data, TimeSeries): # D1 Layer + dm = self._build_datamodule(data) + dm.setup(stage="predict") + return dm.predict_dataloader() + elif isinstance(data, LightningDataModule): # D2 Layer + data.setup(stage="predict") + return data.predict_dataloader() + elif isinstance(data, DataLoader): + return data + else: + raise TypeError( + f"Unsupported data type for prediction: {type(data).__name__}. " + "Expected TimeSeriesDataSet, LightningDataModule, or DataLoader." + ) + + def _save_artifact(self, output_dir: Path): + """Save all configuration artifacts.""" + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_dir / "datamodule_cfg.pkl", "wb") as f: + pickle.dump(self.datamodule_cfg, f) + + with open(output_dir / "model_cfg.pkl", "wb") as f: + pickle.dump(self.model_cfg, f) + + if self.datamodule is not None and hasattr(self.datamodule, "metadata"): + with open(output_dir / "metadata.pkl", "wb") as f: + pickle.dump(self.datamodule.metadata, f) + + def fit( + self, + data: Union[TimeSeries, LightningDataModule], + # todo: we should create a base data_module for different data_modules + save_ckpt: bool = True, + ckpt_dir: Union[str, Path] = "checkpoints", + ckpt_kwargs: Optional[dict[str, Any]] = None, + **trainer_fit_kwargs, + ): + """ + Fit the model to the training data. + + Parameters + ---------- + data : Union[TimeSeries, LightningDataModule] + The data to fit on (D1 or D2 layer). This object is responsible + for providing both training and validation data. + save_ckpt : bool, default=True + If True, save the best model checkpoint and the `datamodule_cfg`. + ckpt_dir : Union[str, Path], default="checkpoints" + Directory to save artifacts. + ckpt_kwargs : dict, optional + Keyword arguments passed to ``ModelCheckpoint``. + **trainer_fit_kwargs : + Additional keyword arguments passed to `trainer.fit()`. + + Returns + ------- + Optional[Path] + The path to the best model checkpoint if `save_ckpt=True`, else None. + """ + if isinstance(data, TimeSeries): + self.datamodule = self._build_datamodule(data) + else: + self.datamodule = data + self.datamodule.setup(stage="fit") + + if self.model is None: + if not self.model_cfg: + raise RuntimeError( + "`model_cfg` must be provided to train from scratch." 
+ ) + metadata = self.datamodule.metadata + self._build_model(metadata) + + callbacks = self.trainer_cfg.get("callbacks", []).copy() + checkpoint_cb = None + if save_ckpt: + ckpt_dir = Path(ckpt_dir) + ckpt_dir.mkdir(parents=True, exist_ok=True) + default_ckpt_kwargs = { + "dirpath": ckpt_dir, + "filename": "best-{epoch}-{step}", + "save_top_k": 1, + "monitor": "val_loss", + "mode": "min", + } + if ckpt_kwargs: + default_ckpt_kwargs.update(ckpt_kwargs) + checkpoint_cb = ModelCheckpoint(**default_ckpt_kwargs) + callbacks.append(checkpoint_cb) + trainer_init_cfg = self.trainer_cfg.copy() + trainer_init_cfg.pop("callbacks", None) + + self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks) + + self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) + if save_ckpt and checkpoint_cb: + best_model_path = Path(checkpoint_cb.best_model_path) + self._save_artifact(best_model_path.parent) + print(f"Artifacts saved in: {best_model_path.parent}") + return best_model_path + return None + + def predict( + self, + data: Union[TimeSeries, LightningDataModule, DataLoader], + output_dir: Optional[Union[str, Path]] = None, + **kwargs, + ) -> Union[dict[str, torch.Tensor], None]: + """ + Generate predictions by wrapping the model's predict method. + + This method prepares the data by resolving it into a DataLoader and then + delegates the prediction task to the underlying model's ``.predict()`` method. + + Parameters + ---------- + data : Union[TimeSeries, LightningDataModule, DataLoader] + The data to predict on (D1, D2, or DataLoader). + output_dir : Union[str, Path], optional + If given, predictions are pickled to this directory instead of returned. + **kwargs : + Additional keyword arguments passed directly to the model's ``.predict()`` + method. This includes `mode`, `return_info`, `mode_kwargs`, and any + `trainer_kwargs`. + + Returns + ------- + Union[dict[str, torch.Tensor], None] + A dictionary of prediction tensors, or `None` if `output_dir` is given. + """ + if self.model is None: + raise RuntimeError( + "Model is not initialized. Provide `model_cfg` or `ckpt_path`." + ) + + dataloader = self._load_dataloader(data) + predictions = self.model.predict(dataloader, **kwargs) + + if output_dir: + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + output_file = output_path / "predictions.pkl" + with open(output_file, "wb") as f: + pickle.dump(predictions, f) + print(f"Predictions saved to {output_file}") + return None + + return predictions diff --git a/pytorch_forecasting/callbacks/__init__.py b/pytorch_forecasting/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/callbacks/predict.py b/pytorch_forecasting/callbacks/predict.py new file mode 100644 index 000000000..0d4dab719 --- /dev/null +++ b/pytorch_forecasting/callbacks/predict.py @@ -0,0 +1,111 @@ +from typing import Any, Optional +from warnings import warn + +from lightning import Trainer +from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import BasePredictionWriter +import torch + + +class PredictCallback(BasePredictionWriter): + """ + Callback to capture predictions and related information internally. + + This callback is used by ``BaseModel.predict()`` to process raw model outputs + into the desired format (``prediction``, ``quantiles``, or ``raw``) and collect + any additional requested info (``x``, ``y``, ``index``, etc.). The results are + collated and stored in memory, accessible via the ``.result`` property.
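[Editorial note placed between hunks, outside the patch: the `Base_pkg.fit` flow above resolves the D1/D2 input, lazily builds the model from `datamodule.metadata`, and injects a `ModelCheckpoint` callback. A minimal usage sketch against this API, assuming a concrete package such as `TFT_pkg_v2` and pre-built config dicts; `my_dataset` is a `TimeSeries` instance and the paths are illustrative:]

```python
# Train from scratch; fit() returns the best checkpoint path when save_ckpt=True.
pkg = TFT_pkg_v2(model_cfg=model_cfg, trainer_cfg=trainer_cfg,
                 datamodule_cfg=datamodule_cfg)
best_ckpt = pkg.fit(my_dataset, ckpt_dir="checkpoints/tft")

# Later, reload for inference only: _load_config picks up the pickled
# datamodule_cfg next to the checkpoint via auto_file_name.
pkg_restored = TFT_pkg_v2(ckpt_path=best_ckpt,
                          datamodule_cfg="checkpoints/tft/datamodule_cfg.pkl")
preds = pkg_restored.predict(my_dataset, mode="prediction")
```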
+ + Parameters + ---------- + mode : str + The prediction mode ("prediction", "quantiles", or "raw"). + return_info : list[str], optional + Additional information to return. + mode_kwargs : dict[str, Any], optional + Additional keyword arguments for `to_prediction` or `to_quantiles`. + """ + + def __init__( + self, + mode: str = "prediction", + return_info: Optional[list[str]] = None, + mode_kwargs: dict[str, Any] = None, + ): + super().__init__(write_interval="epoch") + self.mode = mode + self.return_info = return_info or [] + self.mode_kwargs = mode_kwargs or {} + self._reset_data() + + def _reset_data(self, result: bool = True): + """Clear collected data for a new prediction run.""" + self.predictions = [] + self.info = {key: [] for key in self.return_info} + if result: + self._result = None + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ): + """Process and store predictions for a single batch.""" + x, y = batch + + if self.mode == "raw": + processed_output = outputs + elif self.mode == "prediction": + processed_output = pl_module.to_prediction(outputs, **self.mode_kwargs) + elif self.mode == "quantiles": + processed_output = pl_module.to_quantiles(outputs, **self.mode_kwargs) + else: + raise ValueError(f"Invalid prediction mode: {self.mode}") + + self.predictions.append(processed_output) + + for key in self.return_info: + if key == "x": + self.info[key].append(x) + elif key == "y": + self.info[key].append(y[0]) + elif key == "index": + self.info[key].append(y[1]) + elif key == "decoder_lengths": + self.info[key].append(x["decoder_lengths"]) + else: + warn(f"Unknown return_info key: {key}") + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + """Collate all batch results into final tensors.""" + if self.mode == "raw" and isinstance(self.predictions[0], dict): + keys = self.predictions[0].keys() + collated_preds = { + key: torch.cat([p[key] for p in self.predictions]) for key in keys + } + else: + collated_preds = {"prediction": torch.cat(self.predictions)} + + final_result = collated_preds + + for key, data_list in self.info.items(): + if isinstance(data_list[0], dict): + collated_info = { + k: torch.cat([d[k] for d in data_list]) for k in data_list[0].keys() + } + else: + collated_info = torch.cat(data_list) + final_result[key] = collated_info + + self._result = final_result + self._reset_data(result=False) + + @property + def result(self) -> dict[str, torch.Tensor]: + if self._result is None: + raise RuntimeError("Prediction results are not yet available.") + return self._result diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index 8896a5397..e0affe943 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -5,16 +5,19 @@ ######################################################################################## -from typing import Optional, Union +from typing import Any, Optional, Union from warnings import warn +from lightning import Trainer from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT import torch import torch.nn as nn from torch.optim import Optimizer +from torch.utils.data import DataLoader -from pytorch_forecasting.metrics import Metric +from pytorch_forecasting.callbacks.predict import PredictCallback +from pytorch_forecasting.metrics import Metric, MultiLoss from
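[Editorial note placed between hunks, outside the patch: the callback collates per-batch outputs at epoch end, so the interesting part of its contract is which keys land in `.result`. A minimal sketch of how `BaseModel.predict` is expected to exercise it, per the docstrings above; the `model` and `dataloader` objects are assumed to come from the v2 pipeline shown earlier in this patch:]

```python
# mode="prediction" collates point forecasts under the "prediction" key;
# return_info adds the requested extras as additional keys in the result dict.
result = model.predict(
    dataloader,
    mode="prediction",
    return_info=["x", "y"],
    trainer_kwargs={"accelerator": "auto", "logger": False},
)
print(result["prediction"].shape)  # (n_samples, prediction_length, n_targets)
print(result.keys())               # dict_keys(['prediction', 'x', 'y'])
```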
pytorch_forecasting.utils._classproperty import classproperty @@ -91,6 +94,69 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ raise NotImplementedError("Forward method must be implemented by subclass.") + def predict( + self, + dataloader: DataLoader, + mode: str = "prediction", + return_info: Optional[list[str]] = None, + mode_kwargs: dict[str, Any] = None, + trainer_kwargs: dict[str, Any] = None, + ) -> dict[str, torch.Tensor]: + """ + Generate predictions for new data using the `lightning.Trainer`. + + Parameters + ---------- + dataloader : DataLoader + The dataloader containing the data to predict on. + mode : str + The prediction mode ("prediction", "quantiles", or "raw"). + return_info : list[str], optional + A list of additional information to return. + mode_kwargs : dict[str, Any], optional + Additional arguments for `to_prediction`/`to_quantiles`. + trainer_kwargs : dict[str, Any], optional + Additional arguments for `Trainer`. + + Returns + ------- + dict[str, torch.Tensor] + A dictionary of prediction results. + """ + trainer_kwargs = trainer_kwargs or {} + predict_callback = PredictCallback( + mode=mode, return_info=return_info, mode_kwargs=mode_kwargs + ) + + callbacks = trainer_kwargs.get("callbacks", []) + if not isinstance(callbacks, list): + callbacks = [callbacks] + callbacks.append(predict_callback) + trainer_kwargs["callbacks"] = callbacks + + trainer = Trainer(**trainer_kwargs) + trainer.predict(self, dataloaders=dataloader) + + return predict_callback.result + + def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor: + """Convert raw model output to point forecasts.""" + # todo: add MultiLoss support + try: + out = self.loss.to_prediction(out["prediction"], **kwargs) + except TypeError: # in case passed kwargs do not exist + out = self.loss.to_prediction(out["prediction"]) + return out + + def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor: + """Convert raw model output to quantile forecasts.""" + # todo: add MultiLoss support + try: + out = self.loss.to_quantiles(out["prediction"], **kwargs) + except TypeError: # in case passed kwargs do not exist + out = self.loss.to_quantiles(out["prediction"]) + return out + def training_step( self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: diff --git a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py index bf4fffce5..500446d9d 100644 --- a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py +++ b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py @@ -2,10 +2,10 @@ Packages container for DLinear model.
""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class DLinear_pkg_v2(_BasePtForecasterV2): +class DLinear_pkg_v2(Base_pkg): """DLinear package container.""" _tags = { @@ -26,6 +26,13 @@ def get_cls(cls): return DLinear + @classmethod + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + + return TslibDataModule + @classmethod def _get_test_datamodule_from(cls, trainer_kwargs): """Create test dataloaders from trainer_kwargs - following v1 pattern.""" @@ -112,7 +119,7 @@ def get_test_train_params(cls): from pytorch_forecasting.metrics import MAE, MAPE, SMAPE, QuantileLoss - return [ + params = [ {}, dict(moving_avg=25, individual=False, logging_metrics=[SMAPE()]), dict( @@ -125,3 +132,13 @@ def get_test_train_params(cls): logging_metrics=[SMAPE()], ), ] + + default_dm_cfg = {"context_length": 8, "prediction_length": 2} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py index 36db9340a..2838fcc91 100644 --- a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py +++ b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py @@ -2,10 +2,10 @@ Samformer package container. """ -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class Samformer_pkg_v2(_BasePtForecasterV2): +class Samformer_pkg_v2(Base_pkg): """Samformer package container.""" _tags = { @@ -21,83 +21,13 @@ def get_cls(cls): return Samformer @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, 
- max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -115,7 +45,7 @@ def get_test_train_params(cls): from pytorch_forecasting.metrics import QuantileLoss - return [ + params = [ { # "loss": nn.MSELoss(), "hidden_size": 32, @@ -134,3 +64,13 @@ def get_test_train_params(cls): "use_revin": False, }, ] + + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index 8c95daa6b..d121eba6e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -1,9 +1,9 @@ """TFT package container.""" -from pytorch_forecasting.models.base import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class TFT_pkg_v2(_BasePtForecasterV2): +class TFT_pkg_v2(Base_pkg): """TFT package container.""" _tags = { @@ -23,83 +23,13 @@ def get_cls(cls): return TFT @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - train_datamodule = 
EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -113,19 +43,17 @@ def get_test_train_params(cls): `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. `create_test_instance` uses the first (or only) dictionary in `params` """ - return [ + params = [ {}, dict( hidden_size=25, attention_head_size=5, ), - dict( - data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3) - ), + dict(datamodule_cfg=dict(max_encoder_length=5, max_prediction_length=3)), dict( hidden_size=24, attention_head_size=8, - data_loader_kwargs=dict( + datamodule_cfg=dict( max_encoder_length=5, max_prediction_length=3, add_relative_time_idx=False, @@ -133,7 +61,17 @@ def get_test_train_params(cls): ), dict( hidden_size=12, - data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10), + datamodule_cfg=dict(max_encoder_length=7, max_prediction_length=10), ), dict(attention_head_size=2), ] + + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2.py b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2.py index 70620928e..f49caaf59 100644 --- a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2.py +++ b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2.py @@ -79,7 +79,7 @@ def __init__( """ super().__init__(loss=loss) - self.save_hyperparameters(logger=False) + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) self.dropout = dropout_rate self.persistence_weight = persistence_weight diff --git a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py index d3cf70454..6b2780053 100644 --- a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py +++ b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py @@ -1,9 +1,9 @@ """TIDE package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from 
pytorch_forecasting.base._base_pkg import Base_pkg -class TIDE_pkg_v2(_BasePtForecasterV2): +class TIDE_pkg_v2(Base_pkg): """TIDE package container.""" _tags = { @@ -19,83 +19,13 @@ def get_cls(cls): return TIDE @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -111,7 +41,7 @@ def get_test_train_params(cls): """ from pytorch_forecasting.metrics import MAE, MAPE - return [ + params = [ dict( hidden_size=16, d_model=8, @@ -125,7 +55,7 @@ def get_test_train_params(cls): n_add_enc=2, n_add_dec=2, dropout_rate=0.2, - data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3), + datamodule_cfg=dict(max_encoder_length=5, max_prediction_length=3), loss=MAE(), ), dict( @@ -134,7 +64,16 @@ def get_test_train_params(cls): n_add_enc=3, n_add_dec=2, dropout_rate=0.1, - data_loader_kwargs=dict(max_encoder_length=4, 
max_prediction_length=2), + datamodule_cfg=dict(max_encoder_length=4, max_prediction_length=2), loss=MAPE(), ), ] + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py index 74b27227f..2bb377cc4 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py @@ -2,10 +2,10 @@ Metadata container for TimeXer v2. """ -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class TimeXer_pkg_v2(_BasePtForecasterV2): +class TimeXer_pkg_v2(Base_pkg): """TimeXer metadata container.""" _tags = { @@ -25,77 +25,11 @@ def get_cls(cls): return TimeXer @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data._tslib_data_module import TslibDataModule - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - - context_length = data_loader_kwargs.get("context_length", 12) - prediction_length = data_loader_kwargs.get("prediction_length", 4) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = TslibDataModule( - time_series_dataset=training_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = TslibDataModule( - time_series_dataset=validation_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = TslibDataModule( - time_series_dataset=validation_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return TslibDataModule @classmethod def get_test_train_params(cls): @@ -111,17 +45,17 @@ def get_test_train_params(cls): """ from 
diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
index 74b27227f..2bb377cc4 100644
--- a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
+++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -2,10 +2,10 @@
 Metadata container for TimeXer v2.
 """

-from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
+from pytorch_forecasting.base._base_pkg import Base_pkg


-class TimeXer_pkg_v2(_BasePtForecasterV2):
+class TimeXer_pkg_v2(Base_pkg):
     """TimeXer metadata container."""

     _tags = {
@@ -25,77 +25,11 @@ def get_cls(cls):
         return TimeXer

     @classmethod
-    def _get_test_datamodule_from(cls, trainer_kwargs):
-        """Create test dataloaders from trainer_kwargs - following v1 pattern."""
+    def get_datamodule_cls(cls):
+        """Get the underlying DataModule class."""
         from pytorch_forecasting.data._tslib_data_module import TslibDataModule
-        from pytorch_forecasting.tests._data_scenarios import (
-            data_with_covariates_v2,
-            make_datasets_v2,
-        )

-        data_with_covariates = data_with_covariates_v2()
-
-        data_loader_default_kwargs = dict(
-            target="target",
-            group_ids=["agency_encoded", "sku_encoded"],
-            add_relative_time_idx=True,
-        )
-
-        data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
-        data_loader_default_kwargs.update(data_loader_kwargs)
-
-        datasets_info = make_datasets_v2(
-            data_with_covariates, **data_loader_default_kwargs
-        )
-
-        training_dataset = datasets_info["training_dataset"]
-        validation_dataset = datasets_info["validation_dataset"]
-
-        context_length = data_loader_kwargs.get("context_length", 12)
-        prediction_length = data_loader_kwargs.get("prediction_length", 4)
-        batch_size = data_loader_kwargs.get("batch_size", 2)
-
-        train_datamodule = TslibDataModule(
-            time_series_dataset=training_dataset,
-            context_length=context_length,
-            prediction_length=prediction_length,
-            add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
-            batch_size=batch_size,
-            train_val_test_split=(0.8, 0.2, 0.0),
-        )
-
-        val_datamodule = TslibDataModule(
-            time_series_dataset=validation_dataset,
-            context_length=context_length,
-            prediction_length=prediction_length,
-            add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
-            batch_size=batch_size,
-            train_val_test_split=(0.0, 1.0, 0.0),
-        )
-
-        test_datamodule = TslibDataModule(
-            time_series_dataset=validation_dataset,
-            context_length=context_length,
-            prediction_length=prediction_length,
-            add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
-            batch_size=1,
-            train_val_test_split=(0.0, 0.0, 1.0),
-        )
-
-        train_datamodule.setup("fit")
-        val_datamodule.setup("fit")
-        test_datamodule.setup("test")
-
-        train_dataloader = train_datamodule.train_dataloader()
-        val_dataloader = val_datamodule.val_dataloader()
-        test_dataloader = test_datamodule.test_dataloader()
-
-        return {
-            "train": train_dataloader,
-            "val": val_dataloader,
-            "test": test_dataloader,
-            "data_module": train_datamodule,
-        }
+        return TslibDataModule

     @classmethod
     def get_test_train_params(cls):
@@ -111,17 +45,17 @@ def get_test_train_params(cls):
         """
         from pytorch_forecasting.metrics import QuantileLoss

-        return [
+        params = [
             {},
             dict(
                 hidden_size=64,
                 n_heads=4,
             ),
-            dict(data_loader_kwargs=dict(context_length=12, prediction_length=3)),
+            dict(datamodule_cfg=dict(context_length=12, prediction_length=3)),
             dict(
                 hidden_size=32,
                 n_heads=2,
-                data_loader_kwargs=dict(
+                datamodule_cfg=dict(
                     context_length=12,
                     prediction_length=3,
                     add_relative_time_idx=False,
@@ -130,7 +64,7 @@ def get_test_train_params(cls):
             dict(
                 hidden_size=128,
                 patch_length=12,
-                data_loader_kwargs=dict(context_length=16, prediction_length=4),
+                datamodule_cfg=dict(context_length=16, prediction_length=4),
             ),
             dict(
                 n_heads=2,
@@ -156,7 +90,7 @@ def get_test_train_params(cls):
                 factor=2,
                 activation="relu",
                 dropout=0.05,
-                data_loader_kwargs=dict(
+                datamodule_cfg=dict(
                     context_length=16,
                     prediction_length=4,
                 ),
@@ -172,3 +106,12 @@ def get_test_train_params(cls):
                 use_efficient_attention=True,
             ),
         ]
+        default_dm_cfg = {"context_length": 12, "prediction_length": 4}
+
+        for param in params:
+            current_dm_cfg = param.get("datamodule_cfg", {})
+            # merge into a fresh dict so parameter sets do not alias one dict
+            # and defaults are not mutated across iterations
+            param["datamodule_cfg"] = {**default_dm_cfg, **current_dm_cfg}
+
+        return params
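Both package rewrites above merge per-configuration overrides into a set of defaults. A fresh dict per entry matters here: updating the shared default in place with `dict.update` and assigning it back would leak earlier overrides into later entries and alias one dict across all parameter sets. A self-contained illustration of the merge semantics used:

    default_cfg = {"context_length": 12, "prediction_length": 4}
    overrides = [{}, {"context_length": 16}]

    # dict unpacking builds a new dict per entry; later keys win
    merged = [{**default_cfg, **o} for o in overrides]
    assert merged[0] == {"context_length": 12, "prediction_length": 4}
    assert merged[1] == {"context_length": 16, "prediction_length": 4}
    assert merged[0] is not merged[1]  # no shared state between entries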
diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py
deleted file mode 100644
index 9c28c5d0a..000000000
--- a/pytorch_forecasting/tests/test_all_estimators_v2.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""Automated tests based on the skbase test suite template."""
-
-import shutil
-
-import lightning.pytorch as pl
-from lightning.pytorch.callbacks import EarlyStopping
-from lightning.pytorch.loggers import TensorBoardLogger
-import torch
-import torch.nn as nn
-
-from pytorch_forecasting.metrics import SMAPE
-from pytorch_forecasting.tests.test_all_estimators import (
-    EstimatorFixtureGenerator,
-    EstimatorPackageConfig,
-)
-
-# whether to test only estimators from modules that are changed w.r.t. main
-# default is False, can be set to True by pytest --only_changed_modules True flag
-ONLY_CHANGED_MODULES = False
-
-
-def _integration(
-    estimator_cls,
-    dataloaders,
-    tmp_path,
-    data_loader_kwargs={},
-    clip_target: bool = False,
-    trainer_kwargs=None,
-    **kwargs,
-):
-    train_dataloader = dataloaders["train"]
-    val_dataloader = dataloaders["val"]
-    test_dataloader = dataloaders["test"]
-
-    early_stop_callback = EarlyStopping(
-        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
-    )
-
-    logger = TensorBoardLogger(tmp_path)
-    if trainer_kwargs is None:
-        trainer_kwargs = {}
-    trainer = pl.Trainer(
-        max_epochs=3,
-        gradient_clip_val=0.1,
-        callbacks=[early_stop_callback],
-        enable_checkpointing=True,
-        default_root_dir=tmp_path,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        limit_test_batches=2,
-        logger=logger,
-        **trainer_kwargs,
-    )
-    training_data_module = dataloaders.get("data_module")
-    metadata = training_data_module.metadata
-
-    assert isinstance(
-        metadata, dict
-    ), f"Expected metadata to be dict, got {type(metadata)}"
-
-    if "loss" in kwargs:
-        loss = kwargs["loss"]
-        kwargs.pop("loss")
-    else:
-        loss = SMAPE()
-
-    net = estimator_cls(
-        metadata=metadata,
-        loss=loss,
-        **kwargs,
-    )
-
-    trainer.fit(
-        net,
-        train_dataloaders=train_dataloader,
-        val_dataloaders=val_dataloader,
-    )
-    test_outputs = trainer.test(net, dataloaders=test_dataloader)
-    assert len(test_outputs) > 0
-
-    # todo: add the predict pipeline and make this test cleaner
-    x, y = next(iter(test_dataloader))
-    net.eval()
-    with torch.no_grad():
-        output = net(x)
-    net.train()
-    prediction = output["prediction"]
-    n_dims = prediction.ndim
-    assert n_dims == 3, (
-        f"Prediction output must be 3D, but got {n_dims}D tensor "
-        f"with shape {output.shape}"
-    )
-
-    shutil.rmtree(tmp_path, ignore_errors=True)
-
-
-class TestAllPtForecastersV2(EstimatorPackageConfig, EstimatorFixtureGenerator):
-    """Generic tests for all objects in the mini package."""
-
-    object_type_filter = "forecaster_pytorch_v2"
-
-    def test_doctest_examples(self, object_class):
-        """Runs doctests for estimator class."""
-        from skbase.utils.doctest_run import run_doctest
-
-        run_doctest(object_class, name=f"class {object_class.__name__}")
-
-    def test_integration(
-        self,
-        object_pkg,
-        trainer_kwargs,
-        tmp_path,
-    ):
-        object_class = object_pkg.get_cls()
-        dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs)
-
-        _integration(object_class, dataloaders, tmp_path, **trainer_kwargs)
-
-    def test_pkg_linkage(self, object_pkg, object_class):
-        """Test that the package is linked correctly."""
-        # check name method
-        msg = (
-            f"Package {object_pkg}.name() does not match class "
-            f"name {object_class.__name__}. "
-            "The expected package name is "
-            f"{object_class.__name__}_pkg."
-        )
-        assert object_pkg.name() == object_class.__name__, msg
-
-        # check naming convention
-        msg = (
-            f"Package {object_pkg.__name__} does not match class "
-            f"name {object_class.__name__}. "
-            "The expected package name is "
-            f"{object_class.__name__}_pkg."
-        )
-        assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg
diff --git a/pytorch_forecasting/tests/test_all_v2/__init__.py b/pytorch_forecasting/tests/test_all_v2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pytorch_forecasting/tests/test_all_v2/_test_integration.py b/pytorch_forecasting/tests/test_all_v2/_test_integration.py
new file mode 100644
index 000000000..83c1bdc78
--- /dev/null
+++ b/pytorch_forecasting/tests/test_all_v2/_test_integration.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+import torch
+
+from pytorch_forecasting.base._base_pkg import Base_pkg
+from pytorch_forecasting.data import TimeSeries
+
+
+def _integration(
+    pkg: Base_pkg,
+    test_data: dict[str, TimeSeries],
+    datamodule_cfg: dict[str, Any],
+    **kwargs,
+):
+    """Test integration of models with `TimeSeries` data and datamodules."""
+    pkg.fit(test_data["train"])
+
+    predictions = pkg.predict(
+        test_data["predict"],
+        mode="raw",
+    )
+    assert predictions is not None
+    assert isinstance(predictions, dict)
+    assert "prediction" in predictions
+
+    pred_tensor = predictions["prediction"]
+    assert isinstance(pred_tensor, torch.Tensor)
+    assert pred_tensor.ndim == 3, f"Prediction must be 3D, got {pred_tensor.ndim}D"
+
+    expected_pred_len = datamodule_cfg.get("prediction_length")
+    if expected_pred_len:
+        assert pred_tensor.shape[1] == expected_pred_len, (
+            f"Pred length mismatch: expected {expected_pred_len}, "
+            f"got {pred_tensor.shape[1]}"
+        )
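The new `_integration` helper pins down the v2 output contract: `predict(..., mode="raw")` returns a dict whose `"prediction"` entry is a 3D tensor of shape `(batch, prediction_length, n_outputs)`. A sketch of the same check against a stub output; the shape semantics are inferred from the assertions above, not from a documented API:

    import torch

    # stand-in raw output: 2 samples, horizon of 4 steps, 1 target dimension
    raw_output = {"prediction": torch.zeros(2, 4, 1)}

    pred = raw_output["prediction"]
    assert pred.ndim == 3      # (batch, prediction_length, n_outputs)
    assert pred.shape[1] == 4  # matches datamodule_cfg["prediction_length"]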
diff --git a/pytorch_forecasting/tests/test_all_v2/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_v2/test_all_estimators_v2.py
new file mode 100644
index 000000000..87123028b
--- /dev/null
+++ b/pytorch_forecasting/tests/test_all_v2/test_all_estimators_v2.py
@@ -0,0 +1,134 @@
+"""Automated tests based on the skbase test suite template."""
+
+import os
+from pathlib import Path
+import shutil
+
+import torch
+
+from pytorch_forecasting.tests.test_all_estimators import (
+    EstimatorFixtureGenerator,
+    EstimatorPackageConfig,
+)
+from pytorch_forecasting.tests.test_all_v2._test_integration import _integration
+from pytorch_forecasting.tests.test_all_v2.utils import _setup_pkg_and_data
+
+# whether to test only estimators from modules that are changed w.r.t. main
+# default is False, can be set to True by pytest --only_changed_modules True flag
+ONLY_CHANGED_MODULES = False
+
+
+class TestAllPtForecastersV2(EstimatorPackageConfig, EstimatorFixtureGenerator):
+    """Generic tests for all objects in the mini package."""
+
+    object_type_filter = "forecaster_pytorch_v2"
+
+    def test_doctest_examples(self, object_class):
+        """Runs doctests for estimator class."""
+        from skbase.utils.doctest_run import run_doctest
+
+        run_doctest(object_class, name=f"class {object_class.__name__}")
+
+    def test_integration(
+        self,
+        object_pkg,
+        trainer_kwargs,
+        tmp_path,
+    ):
+        pkg, test_data, dm_cfg = _setup_pkg_and_data(
+            object_pkg, trainer_kwargs, tmp_path
+        )
+
+        _integration(pkg, test_data, dm_cfg)
+
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+    def test_checkpointing(self, object_pkg, trainer_kwargs, tmp_path):
+        """Test that the package can save a checkpoint and reload from it."""
+        pkg, test_data, _ = _setup_pkg_and_data(object_pkg, trainer_kwargs, tmp_path)
+
+        ckpt_dir = Path(tmp_path) / "checkpoints"
+        best_model_path = pkg.fit(
+            test_data["train"],
+            save_ckpt=True,
+            ckpt_dir=ckpt_dir,
+            ckpt_kwargs={"monitor": "train_loss_epoch"},
+        )
+
+        assert best_model_path is not None
+        assert os.path.exists(best_model_path)
+
+        dm_cfg_path = Path(best_model_path).parent / "model_cfg.pkl"
+        assert (
+            dm_cfg_path.exists()
+        ), "model_cfg.pkl was not saved alongside checkpoint"
+
+        pkg_loaded = object_pkg(ckpt_path=best_model_path)
+
+        predictions = pkg_loaded.predict(test_data["predict"], mode="prediction")
+
+        assert predictions is not None
+        assert "prediction" in predictions
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+    def test_predict_modes(self, object_pkg, trainer_kwargs, tmp_path):
+        """Test different prediction modes and return_info."""
+        pkg, test_data, _ = _setup_pkg_and_data(object_pkg, trainer_kwargs, tmp_path)
+
+        pkg.fit(test_data["train"], save_ckpt=False)
+        predict_data = test_data["predict"]
+
+        # mode="raw"
+        raw_out = pkg.predict(predict_data, mode="raw")
+        raw_pred_tensor = raw_out["prediction"]
+        assert any(isinstance(v, torch.Tensor) for v in raw_out.values())
+        assert (
+            raw_pred_tensor.ndim == 3
+        ), f"Prediction must be 3D, got {raw_pred_tensor.ndim}D"
+
+        # mode="quantiles"
+        quantile_out = pkg.predict(predict_data, mode="quantiles")
+        quantile_pred_tensor = quantile_out["prediction"]
+        assert isinstance(quantile_pred_tensor, torch.Tensor)
+        assert (
+            quantile_pred_tensor.ndim == 3
+        ), f"Prediction must be 3D, got {quantile_pred_tensor.ndim}D"
+
+        # mode="prediction"
+        pred_out = pkg.predict(predict_data, mode="prediction")
+        pred_tensor = pred_out["prediction"]
+        assert isinstance(pred_tensor, torch.Tensor)
+        assert pred_tensor.ndim == 2, f"Prediction must be 2D, got {pred_tensor.ndim}D"
+
+        return_info_keys = ["index", "x"]
+        info_out = pkg.predict(
+            predict_data, mode="prediction", return_info=return_info_keys
+        )
+
+        for key in return_info_keys:
+            assert key in info_out, f"Requested key '{key}' missing from output"
+
+        assert info_out["index"] is not None
+        assert isinstance(info_out["x"], dict)
+
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+    def test_pkg_linkage(self, object_pkg, object_class):
+        """Test that the package is linked correctly."""
+        # check name method
+        msg = (
+            f"Package {object_pkg}.name() does not match class "
+            f"name {object_class.__name__}. "
+            "The expected package name is "
+            f"{object_class.__name__}."
+        )
+        assert object_pkg.name() == object_class.__name__, msg
+
+        # check naming convention
+        msg = (
+            f"Package {object_pkg.__name__} does not match class "
+            f"name {object_class.__name__}. "
+            "The expected package name is "
+            f"{object_class.__name__}_pkg_v2."
+        )
+        assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg
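Condensed, the checkpoint round-trip that `test_checkpointing` exercises looks roughly as follows; `pkg`, `train_data`, and `new_data` are placeholders, and the `save_ckpt`/`ckpt_kwargs`/`ckpt_path` surface is taken from the test above rather than from a documented API:

    from pathlib import Path

    # fit returns the best checkpoint path when save_ckpt=True
    best_path = pkg.fit(
        train_data,
        save_ckpt=True,
        ckpt_dir=Path("checkpoints"),
        ckpt_kwargs={"monitor": "train_loss_epoch"},
    )

    # a fresh package restored from the checkpoint can predict directly
    restored = type(pkg)(ckpt_path=best_path)
    point_forecast = restored.predict(new_data, mode="prediction")["prediction"]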
diff --git a/pytorch_forecasting/tests/test_all_v2/utils.py b/pytorch_forecasting/tests/test_all_v2/utils.py
new file mode 100644
index 000000000..a8cb714dc
--- /dev/null
+++ b/pytorch_forecasting/tests/test_all_v2/utils.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+from lightning.pytorch.loggers import TensorBoardLogger
+
+from pytorch_forecasting.base._base_pkg import Base_pkg
+from pytorch_forecasting.data import TimeSeries
+from pytorch_forecasting.metrics import SMAPE
+
+
+def _setup_pkg_and_data(
+    estimator_cls: type[Base_pkg],
+    trainer_kwargs: dict[str, Any],
+    tmp_path: str,
+) -> tuple[Base_pkg, dict[str, TimeSeries], dict[str, Any]]:
+    """
+    Helper to initialize the package, datasets, and configs.
+
+    Returns
+    -------
+    pkg : Base_pkg
+        The initialized model package.
+    test_data : dict
+        Dictionary containing 'train' and 'predict' TimeSeries datasets.
+    datamodule_cfg : dict
+        The final datamodule configuration used.
+    """
+    params_copy = trainer_kwargs.copy()
+    datamodule_cfg = params_copy.pop("datamodule_cfg", {})
+    model_cfg = params_copy
+
+    if "loss" not in model_cfg:
+        model_cfg["loss"] = SMAPE()
+
+    default_datamodule_cfg = {
+        "train_val_test_split": (0.8, 0.2),
+        "add_relative_time_idx": True,
+        "batch_size": 2,
+    }
+    default_datamodule_cfg.update(datamodule_cfg)
+
+    logger = TensorBoardLogger(str(tmp_path))
+    trainer_cfg = {
+        "max_epochs": 2,
+        "gradient_clip_val": 0.1,
+        "enable_checkpointing": True,
+        "default_root_dir": str(tmp_path),
+        "limit_train_batches": 2,
+        "limit_val_batches": 1,
+        "accelerator": "cpu",
+        "logger": logger,
+    }
+
+    test_data = estimator_cls.get_test_dataset_from(**default_datamodule_cfg)
+
+    pkg = estimator_cls(
+        model_cfg=model_cfg,
+        trainer_cfg=trainer_cfg,
+        datamodule_cfg=default_datamodule_cfg,
+    )
+
+    return pkg, test_data, default_datamodule_cfg
diff --git a/pytorch_forecasting/tests/test_class_register.py b/pytorch_forecasting/tests/test_class_register.py
index 2a1052125..a9699fb99 100644
--- a/pytorch_forecasting/tests/test_class_register.py
+++ b/pytorch_forecasting/tests/test_class_register.py
@@ -20,7 +20,9 @@ def get_test_class_registry():
     keys are scitypes, values are test classes TestAll[Scitype]
     """
     from pytorch_forecasting.tests.test_all_estimators import TestAllPtForecasters
-    from pytorch_forecasting.tests.test_all_estimators_v2 import TestAllPtForecastersV2
+    from pytorch_forecasting.tests.test_all_v2.test_all_estimators_v2 import (
+        TestAllPtForecastersV2,
+    )

     testclass_dict = dict()
     testclass_dict["forecaster_pytorch_v1"] = TestAllPtForecasters
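Putting the pieces of `utils.py` together, the end-to-end workflow these tests rely on reduces to roughly the following sketch; `TIDE_pkg_v2` stands in for any v2 package, its import path is assumed, and the config values mirror the test defaults above:

    from pytorch_forecasting.metrics import SMAPE
    from pytorch_forecasting.models.tide._tide_pkg_v2 import TIDE_pkg_v2  # path assumed

    datamodule_cfg = {"max_encoder_length": 4, "max_prediction_length": 3, "batch_size": 2}
    trainer_cfg = {"max_epochs": 2, "accelerator": "cpu"}
    model_cfg = {"loss": SMAPE()}

    # packages can synthesize their own train/predict TimeSeries for testing
    test_data = TIDE_pkg_v2.get_test_dataset_from(**datamodule_cfg)

    pkg = TIDE_pkg_v2(
        model_cfg=model_cfg,
        trainer_cfg=trainer_cfg,
        datamodule_cfg=datamodule_cfg,
    )
    pkg.fit(test_data["train"])
    raw = pkg.predict(test_data["predict"], mode="raw")["prediction"]  # 3D tensor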