diff --git a/pytorch_forecasting/layers/_encoders/_encoder.py b/pytorch_forecasting/layers/_encoders/_encoder.py
index 3b54a0838..1d2c0e773 100644
--- a/pytorch_forecasting/layers/_encoders/_encoder.py
+++ b/pytorch_forecasting/layers/_encoders/_encoder.py
@@ -2,39 +2,50 @@
 Implementation of encoder layers from `nn.Module`.
 """
 
-import math
-from math import sqrt
-
-import numpy as np
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Encoder(nn.Module):
     """
-    Encoder module for the TimeXer model.
+    Encoder module for Tslib models.
 
     Args:
         layers (list): List of encoder layers.
         norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
         projection (nn.Module, optional): Projection layer. Defaults to None.
-    """
+        output_attention (bool, optional): Whether to output attention weights. Defaults to False.
+    """  # noqa: E501
 
-    def __init__(self, layers, norm_layer=None, projection=None):
+    def __init__(
+        self, layers, norm_layer=None, projection=None, output_attention=False
+    ):
         super().__init__()
         self.layers = nn.ModuleList(layers)
         self.norm = norm_layer
         self.projection = projection
-
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        for layer in self.layers:
-            x = layer(
-                x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
-            )
+        self.output_attention = output_attention
+
+    def forward(
+        self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None
+    ):
+        if self.output_attention:
+            attns = []
+            for layer in self.layers:
+                x, attn = layer(
+                    x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
+                )
+                attns.append(attn)
+        else:
+            for layer in self.layers:
+                x = layer(
+                    x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
+                )
 
         if self.norm is not None:
             x = self.norm(x)
 
         if self.projection is not None:
             x = self.projection(x)
+
+        if self.output_attention:
+            return x, attns
         return x
diff --git a/pytorch_forecasting/layers/_encoders/_encoder_layer.py b/pytorch_forecasting/layers/_encoders/_encoder_layer.py
index a246edc91..0e03d437c 100644
--- a/pytorch_forecasting/layers/_encoders/_encoder_layer.py
+++ b/pytorch_forecasting/layers/_encoders/_encoder_layer.py
@@ -2,10 +2,6 @@
 Implementation of EncoderLayer for encoder-decoder architectures from `nn.Module`.
 """
 
-import math
-from math import sqrt
-
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,25 +9,27 @@
 
 class EncoderLayer(nn.Module):
     """
-    Encoder layer for the TimeXer model.
+    Encoder layer for Tslib models.
 
     Args:
         self_attention (nn.Module): Self-attention mechanism.
-        cross_attention (nn.Module): Cross-attention mechanism.
+        cross_attention (nn.Module, optional): Cross-attention mechanism.
        d_model (int): Dimension of the model.
        d_ff (int, optional): Dimension of the feedforward layer. Defaults to 4 * d_model.
        dropout (float): Dropout rate. Defaults to 0.1.
        activation (str): Activation function. Defaults to "relu".
-    """
+        output_attention (bool, optional): Whether to output attention weights. Defaults to False.
+    """  # noqa: E501
 
     def __init__(
         self,
         self_attention,
-        cross_attention,
-        d_model,
+        cross_attention=None,
+        d_model=512,
         d_ff=None,
         dropout=0.1,
         activation="relu",
+        output_attention=False,
     ):
         super().__init__()
         d_ff = d_ff or 4 * d_model
@@ -40,34 +38,45 @@ def __init__(
         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
         self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
+        if self.cross_attention is not None:
+            self.norm2 = nn.LayerNorm(d_model)
         self.norm3 = nn.LayerNorm(d_model)
         self.dropout = nn.Dropout(dropout)
         self.activation = F.relu if activation == "relu" else F.gelu
+        self.output_attention = output_attention
 
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        B, L, D = cross.shape
-        x = x + self.dropout(
-            self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
-        )
-        x = self.norm1(x)
-
-        x_glb_ori = x[:, -1, :].unsqueeze(1)
-        x_glb = torch.reshape(x_glb_ori, (B, -1, D))
-        x_glb_attn = self.dropout(
-            self.cross_attention(
-                x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
-            )[0]
-        )
-        x_glb_attn = torch.reshape(
-            x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])
-        ).unsqueeze(1)
-        x_glb = x_glb_ori + x_glb_attn
-        x_glb = self.norm2(x_glb)
+    def forward(
+        self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None
+    ):
+        if self.output_attention:
+            out, attn = self.self_attention(
+                x, x, x, attn_mask=x_mask, tau=tau, delta=None
+            )
+        else:
+            out = self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
+        x = x + self.dropout(out)
+        y = x = self.norm1(x)
+        if self.cross_attention is not None:
+            B, L, D = cross.shape
+            x_glb_ori = x[:, -1, :].unsqueeze(1)
+            x_glb = torch.reshape(x_glb_ori, (B, -1, D))
+            x_glb_attn = self.dropout(
+                self.cross_attention(
+                    x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
+                )[0]
+            )
+            x_glb_attn = torch.reshape(
+                x_glb_attn,
+                (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]),
+            ).unsqueeze(1)
+            x_glb = x_glb_ori + x_glb_attn
+            x_glb = self.norm2(x_glb)
 
-        y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
+            y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
 
         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
         y = self.dropout(self.conv2(y).transpose(-1, 1))
 
+        if self.output_attention:
+            return self.norm3(x + y), attn
         return self.norm3(x + y)
diff --git a/pytorch_forecasting/models/itransformer/__init__.py b/pytorch_forecasting/models/itransformer/__init__.py
new file mode 100644
index 000000000..eaa9f79f0
--- /dev/null
+++ b/pytorch_forecasting/models/itransformer/__init__.py
@@ -0,0 +1,13 @@
+"""
+iTransformer model for forecasting time series.
+""" + +from pytorch_forecasting.models.itransformer._itransformer_pkg_v2 import ( + iTransformer_pkg_v2, +) +from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer + +__all__ = [ + "iTransformer", + "iTransformer_pkg_v2", +] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py new file mode 100644 index 000000000..dc923a313 --- /dev/null +++ b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py @@ -0,0 +1,133 @@ +"""iTransformer package container v2.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 + + +class iTransformer_pkg_v2(_BasePtForecasterV2): + """iTransformer metadata container.""" + + _tags = { + "info:name": "iTransformer", + "authors": ["JATAYU000"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + "capability:cold_start": False, + } + + @classmethod + def get_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.itransformer._itransformer_v2 import ( + iTransformer, + ) + + return iTransformer + + @classmethod + def _get_test_datamodule_from(cls, trainer_kwargs): + """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + data_with_covariates = data_with_covariates_v2() + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + + data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) + data_loader_default_kwargs.update(data_loader_kwargs) + + datasets_info = make_datasets_v2( + data_with_covariates, **data_loader_default_kwargs + ) + + training_dataset = datasets_info["training_dataset"] + validation_dataset = datasets_info["validation_dataset"] + + context_length = data_loader_kwargs.get("context_length", 12) + prediction_length = data_loader_kwargs.get("prediction_length", 4) + batch_size = data_loader_kwargs.get("batch_size", 2) + + train_datamodule = TslibDataModule( + time_series_dataset=training_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=batch_size, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = TslibDataModule( + time_series_dataset=validation_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=batch_size, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = TslibDataModule( + time_series_dataset=validation_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=1, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @classmethod + def 
+    def get_test_train_params(cls):
+        """Get test train params."""
+        from pytorch_forecasting.metrics import QuantileLoss
+
+        return [
+            {},
+            dict(d_model=16, n_heads=2, e_layers=2, d_ff=64),
+            dict(
+                d_model=32,
+                n_heads=4,
+                e_layers=3,
+                d_ff=128,
+                dropout=0.1,
+                data_loader_kwargs=dict(
+                    batch_size=4, context_length=8, prediction_length=4
+                ),
+            ),
+            dict(
+                hidden_size=32,
+                n_heads=2,
+                e_layers=1,
+                d_ff=64,
+                factor=2,
+                activation="relu",
+                dropout=0.05,
+                data_loader_kwargs=dict(
+                    context_length=16,
+                    prediction_length=4,
+                ),
+                loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
+            ),
+        ]
diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py
new file mode 100644
index 000000000..d1d1427cb
--- /dev/null
+++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py
@@ -0,0 +1,222 @@
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel
+
+
+class iTransformer(TslibBaseModel):
+    """
+    An implementation of the iTransformer model for v2 of pytorch-forecasting.
+
+    iTransformer repurposes the Transformer architecture by applying attention
+    and feed-forward networks on inverted dimensions. Instead of treating
+    timestamps as tokens (like traditional Transformers), iTransformer embeds
+    individual time series as variate tokens. The attention mechanism captures
+    multivariate correlations, while the feed-forward network learns nonlinear
+    representations for each variate. This inversion enables better handling
+    of long lookback windows, improved generalization across different variates,
+    and state-of-the-art performance on real-world forecasting tasks without
+    modifying the basic Transformer components.
+
+    Parameters
+    ----------
+    loss: nn.Module
+        Loss function to use for training.
+    output_attention: bool, default=False
+        Whether to output attention weights.
+    factor: int, default=5
+        Attention factor passed to the attention mechanism, as in the reference implementation.
+    d_model: int, default=512
+        Dimension of the model embeddings and hidden representations.
+    d_ff: int, default=2048
+        Dimension of the feed-forward network.
+    activation: str, default='relu'
+        Activation function to use in the feed-forward network.
+    dropout: float, default=0.1
+        Dropout rate for regularization.
+    n_heads: int, default=8
+        Number of attention heads in the multi-head attention mechanism.
+    e_layers: int, default=3
+        Number of encoder layers in the transformer architecture.
+    logging_metrics: Optional[list[nn.Module]], default=None
+        List of metrics to log during training, validation, and testing.
+    optimizer: Optional[Union[Optimizer, str]], default='adam'
+        Optimizer to use for training. Can be a string name or an instance of an optimizer.
+    optimizer_params: Optional[dict], default=None
+        Parameters for the optimizer. If None, default parameters for the optimizer will be used.
+    lr_scheduler: Optional[str], default=None
+        Learning rate scheduler to use. If None, no scheduler is used.
+    lr_scheduler_params: Optional[dict], default=None
+        Parameters for the learning rate scheduler. If None, default parameters for the scheduler will be used.
+    metadata: Optional[dict], default=None
+        Metadata for the model from TslibDataModule. This can include information about the dataset,
+        such as the number of time steps, number of features, etc. It is used to initialize the model
+        and ensure it is compatible with the data being used.
+
+    References
+    ----------
+    [1] https://arxiv.org/pdf/2310.06625
+    [2] https://github.com/thuml/iTransformer/blob/main/model/iTransformer.py
+
+    Notes
+    -----
+    [1] The `iTransformer` model obtains many of its attributes from the `TslibBaseModel` class,
+        a base class that implements much of the boilerplate for metadata handling and model initialization.
+    """  # noqa: E501
+
+    @classmethod
+    def _pkg(cls):
+        """Package containing the model."""
+        from pytorch_forecasting.models.itransformer._itransformer_pkg_v2 import (
+            iTransformer_pkg_v2,
+        )
+
+        return iTransformer_pkg_v2
+
+    def __init__(
+        self,
+        loss: nn.Module,
+        output_attention: bool = False,
+        factor: int = 5,
+        d_model: int = 512,
+        d_ff: int = 2048,
+        activation: str = "relu",
+        dropout: float = 0.1,
+        n_heads: int = 8,
+        e_layers: int = 3,
+        logging_metrics: Optional[list[nn.Module]] = None,
+        optimizer: Optional[Union[Optimizer, str]] = "adam",
+        optimizer_params: Optional[dict] = None,
+        lr_scheduler: Optional[str] = None,
+        lr_scheduler_params: Optional[dict] = None,
+        metadata: Optional[dict] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            loss=loss,
+            logging_metrics=logging_metrics,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_params=lr_scheduler_params,
+            metadata=metadata,
+        )
+
+        self.output_attention = output_attention
+        self.factor = factor
+        self.d_model = d_model
+        self.d_ff = d_ff
+        self.activation = activation
+        self.dropout = dropout
+        self.n_heads = n_heads
+        self.e_layers = e_layers
+        self.freq = self.metadata.get("freq", "h")
+
+        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])
+
+        self._init_network()
+
+    def _init_network(self):
+        """
+        Initialize the network for iTransformer's architecture.
+        """
+        from pytorch_forecasting.layers import (
+            AttentionLayer,
+            DataEmbedding_inverted,
+            Encoder,
+            EncoderLayer,
+            FullAttention,
+        )
+
+        self.enc_embedding = DataEmbedding_inverted(
+            self.context_length, self.d_model, self.dropout
+        )
+
+        self.n_quantiles = None
+
+        if hasattr(self.loss, "quantiles") and self.loss.quantiles is not None:
+            self.n_quantiles = len(self.loss.quantiles)
+
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    self_attention=AttentionLayer(
+                        FullAttention(
+                            False,
+                            self.factor,
+                            attention_dropout=self.dropout,
+                            output_attention=self.output_attention,
+                        ),
+                        self.d_model,
+                        self.n_heads,
+                    ),
+                    d_model=self.d_model,
+                    d_ff=self.d_ff,
+                    dropout=self.dropout,
+                    activation=self.activation,
+                    output_attention=self.output_attention,
+                )
+                for _ in range(self.e_layers)
+            ],
+            norm_layer=torch.nn.LayerNorm(self.d_model),
+            output_attention=self.output_attention,
+        )
+        if self.n_quantiles is not None:
+            self.projector = nn.Linear(
+                self.d_model, self.prediction_length * self.n_quantiles, bias=True
+            )
+        else:
+            self.projector = nn.Linear(self.d_model, self.prediction_length, bias=True)
+
+    def _forecast(
+        self, x: dict[str, torch.Tensor]
+    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
+        """
+        Forecasting pass of the iTransformer model.
+        Args:
+            x (dict[str, torch.Tensor]): Input data.
+        Returns:
+            tuple[torch.Tensor, Optional[list[torch.Tensor]]]: Model predictions and
+            encoder attention weights (None when ``output_attention`` is disabled).
+        """
+        x_enc = x["history_target"]
+        x_mark_enc = x["history_cont"]
+
+        _, _, N = x_enc.shape  # B L N
+        # Embedding
+        # B L N -> B N E
+        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # covariates (e.g. timestamp)
+        # B N E -> B N E
+        # the dimensions of the embedded time series have been inverted
+        if self.output_attention:
+            enc_out, attns = self.encoder(enc_out, x_mask=None)
+        else:
+            enc_out = self.encoder(enc_out, x_mask=None)
+            attns = None
+
+        # B N E -> B N S -> B S N
+        dec_out = self.projector(enc_out).permute(0, 2, 1)[
+            :, :, :N
+        ]  # filter covariates
+        return dec_out, attns
+
+    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """
+        Forward pass of the iTransformer model.
+        Args:
+            x (dict[str, torch.Tensor]): Input data.
+        Returns:
+            dict[str, torch.Tensor]: Model predictions.
+        """
+        dec_out, attns = self._forecast(x)
+
+        if self.n_quantiles is not None:
+            batch_size = dec_out.shape[0]
+            dec_out = dec_out.reshape(
+                batch_size, self.prediction_length, self.n_quantiles
+            )
+
+        prediction = dec_out[:, -self.prediction_length :, :]
+
+        if "target_scale" in x:
+            prediction = self.transform_output(prediction, x["target_scale"])
+
+        if self.output_attention:
+            return {"prediction": prediction, "attention": attns}
+        return {"prediction": prediction}
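For reviewers who want to exercise the new `output_attention` plumbing in isolation, below is a minimal sketch (not part of the diff) that wires up the encoder stack the same way `_init_network` does and pushes a dummy batch of inverted variate tokens through it. It assumes `AttentionLayer`, `DataEmbedding_inverted`, `Encoder`, `EncoderLayer`, and `FullAttention` are importable from `pytorch_forecasting.layers`, as the import in `_init_network` suggests; the sizes and covariate count are arbitrary illustration values.

import torch
import torch.nn as nn

from pytorch_forecasting.layers import (
    AttentionLayer,
    DataEmbedding_inverted,
    Encoder,
    EncoderLayer,
    FullAttention,
)

# Illustrative sizes (assumptions, not taken from the diff).
batch_size, context_length, n_vars = 4, 16, 3
d_model, n_heads, d_ff, e_layers, dropout, factor = 32, 4, 64, 2, 0.1, 5

# History in (batch, time, variates) layout plus time features, as in `_forecast`.
x_enc = torch.randn(batch_size, context_length, n_vars)
x_mark_enc = torch.randn(batch_size, context_length, 2)

# Inverted embedding: each variate (and covariate) becomes one token of size d_model.
enc_embedding = DataEmbedding_inverted(context_length, d_model, dropout)
enc_out = enc_embedding(x_enc, x_mark_enc)  # (batch, n_vars + n_covariates, d_model)

# Encoder stack mirroring `_init_network`, with attention weights requested.
encoder = Encoder(
    [
        EncoderLayer(
            self_attention=AttentionLayer(
                FullAttention(
                    False, factor, attention_dropout=dropout, output_attention=True
                ),
                d_model,
                n_heads,
            ),
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation="relu",
            output_attention=True,
        )
        for _ in range(e_layers)
    ],
    norm_layer=nn.LayerNorm(d_model),
    output_attention=True,
)

out, attns = encoder(enc_out, x_mask=None)
print(out.shape)   # (batch, n_vars + n_covariates, d_model)
print(len(attns))  # one attention tensor per encoder layer

With `output_attention=False` everywhere (the default), the encoder returns only the output tensor, which is why `_forecast` unpacks attention weights solely when the flag is set.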