41 changes: 26 additions & 15 deletions pytorch_forecasting/layers/_encoders/_encoder.py
@@ -2,39 +2,50 @@
Implementation of encoder layers based on `nn.Module`.
"""

import math
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
"""
Encoder module for the TimeXer model.
Encoder module for Tslib models.
Args:
layers (list): List of encoder layers.
norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
projection (nn.Module, optional): Projection layer. Defaults to None.
"""
        output_attention (bool, optional): Whether to output attention weights. Defaults to False.
""" # noqa: E501

def __init__(self, layers, norm_layer=None, projection=None):
def __init__(
self, layers, norm_layer=None, projection=None, output_attention=False
):
super().__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection

def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
for layer in self.layers:
x = layer(
x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
)
self.output_attention = output_attention

def forward(
self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None
):
if self.output_attention:
attns = []
for layer in self.layers:
x, attn = layer(
x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
)
attns.append(attn)
else:
for layer in self.layers:
x = layer(
x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
)

if self.norm is not None:
x = self.norm(x)

if self.projection is not None:
x = self.projection(x)

if self.output_attention:
return x, attns
return x
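
For review context, a minimal usage sketch (not part of this diff) of the new output_attention behaviour: with the flag set, forward returns the encoded tensor together with the per-layer attention outputs; with it unset, the return value is a plain tensor as before. The _StubLayer below is hypothetical and only mimics the layer call signature; the import path follows the file location shown above.

import torch
import torch.nn as nn

from pytorch_forecasting.layers._encoders._encoder import Encoder


class _StubLayer(nn.Module):
    """Hypothetical placeholder matching the (x, cross, x_mask, cross_mask, tau, delta) layer interface."""

    def forward(self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None):
        # Pass the input through untouched and return a dummy attention map alongside it.
        return x, torch.zeros(x.size(0), x.size(1), x.size(1))


d_model = 16
encoder = Encoder([_StubLayer()], norm_layer=nn.LayerNorm(d_model), output_attention=True)

x = torch.randn(2, 8, d_model)   # (batch, time, d_model)
out, attns = encoder(x)          # cross is now optional; a tuple comes back because output_attention=True
print(out.shape, len(attns))     # torch.Size([2, 8, 16]) 1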
69 changes: 39 additions & 30 deletions pytorch_forecasting/layers/_encoders/_encoder_layer.py
@@ -2,36 +2,34 @@
Implementation of EncoderLayer for encoder-decoder architectures, based on `nn.Module`.
"""

import math
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderLayer(nn.Module):
"""
Encoder layer for the TimeXer model.
    Encoder layer for Tslib models.
Args:
self_attention (nn.Module): Self-attention mechanism.
cross_attention (nn.Module): Cross-attention mechanism.
cross_attention (nn.Module, optional): Cross-attention mechanism.
        d_model (int, optional): Dimension of the model. Defaults to 512.
d_ff (int, optional):
Dimension of the feedforward layer. Defaults to 4 * d_model.
dropout (float): Dropout rate. Defaults to 0.1.
activation (str): Activation function. Defaults to "relu".
"""
        output_attention (bool, optional): Whether to output attention weights. Defaults to False.
""" # noqa: E501

def __init__(
self,
self_attention,
cross_attention,
d_model,
cross_attention=None,
d_model=512,
d_ff=None,
dropout=0.1,
activation="relu",
output_attention=False,
):
super().__init__()
d_ff = d_ff or 4 * d_model
@@ -40,34 +38,45 @@ def __init__(
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
if self.cross_attention is not None:
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
self.output_attention = output_attention

def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
B, L, D = cross.shape
x = x + self.dropout(
self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
)
x = self.norm1(x)

x_glb_ori = x[:, -1, :].unsqueeze(1)
x_glb = torch.reshape(x_glb_ori, (B, -1, D))
x_glb_attn = self.dropout(
self.cross_attention(
x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)
x_glb_attn = torch.reshape(
x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])
).unsqueeze(1)
x_glb = x_glb_ori + x_glb_attn
x_glb = self.norm2(x_glb)
def forward(
self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None
):
        if self.output_attention:
            new_x, attn = self.self_attention(
                x, x, x, attn_mask=x_mask, tau=tau, delta=None
            )
        else:
            new_x = self.self_attention(
                x, x, x, attn_mask=x_mask, tau=tau, delta=None
            )[0]
        x = x + self.dropout(new_x)
y = x = self.norm1(x)
if self.cross_attention is not None:
B, L, D = cross.shape
x_glb_ori = x[:, -1, :].unsqueeze(1)
x_glb = torch.reshape(x_glb_ori, (B, -1, D))
x_glb_attn = self.dropout(
self.cross_attention(
x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)
x_glb_attn = torch.reshape(
x_glb_attn,
(x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]),
).unsqueeze(1)
x_glb = x_glb_ori + x_glb_attn
x_glb = self.norm2(x_glb)

y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)

y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))

if self.output_attention:
return self.norm3(x + y), attn
return self.norm3(x + y)
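
Again only a sketch for review, not code from this PR: the cross_attention=None default lets the same EncoderLayer run as a plain self-attention block, skipping the global-token/cross branch (and norm2) entirely. The _StubSelfAttention module is hypothetical and simply satisfies the (queries, keys, values, attn_mask, tau, delta) -> (output, attention) contract that the real attention layers implement.

import torch
import torch.nn as nn

from pytorch_forecasting.layers._encoders._encoder_layer import EncoderLayer


class _StubSelfAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        # Identity "attention": return the queries plus a dummy attention map.
        return queries, torch.zeros(queries.size(0), queries.size(1), queries.size(1))


layer = EncoderLayer(
    _StubSelfAttention(),
    d_model=16,              # the new default is 512; kept small for the sketch
    d_ff=32,
    output_attention=True,
)

x = torch.randn(2, 8, 16)
out, attn = layer(x)         # cross defaults to None, so the cross-attention branch never runs
print(out.shape)             # torch.Size([2, 8, 16])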
13 changes: 13 additions & 0 deletions pytorch_forecasting/models/itransformer/__init__.py
@@ -0,0 +1,13 @@
"""
iTransformer model for forecasting time series.
"""

from pytorch_forecasting.models.itransformer._itransformer_pkg_v2 import (
iTransformer_pkg_v2,
)
from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer

__all__ = [
"iTransformer",
"iTransformer_pkg_v2",
]
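
A quick import sanity check for the new subpackage exports (a sketch, not part of the diff); it relies only on names defined in the files added here.

from pytorch_forecasting.models.itransformer import iTransformer, iTransformer_pkg_v2

# The package container should resolve to the same class the subpackage re-exports.
assert iTransformer_pkg_v2.get_cls() is iTransformer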
133 changes: 133 additions & 0 deletions pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py
@@ -0,0 +1,133 @@
"""iTransformer package container v2."""

from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2


class iTransformer_pkg_v2(_BasePtForecasterV2):
"""iTransformer metadata container."""

_tags = {
"info:name": "iTransformer",
"authors": ["JATAYU000"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": False,
"capability:cold_start": False,
}

@classmethod
def get_cls(cls):
"""Get model class."""
from pytorch_forecasting.models.itransformer._itransformer_v2 import (
iTransformer,
)

return iTransformer

@classmethod
def _get_test_datamodule_from(cls, trainer_kwargs):
"""Create test dataloaders from trainer_kwargs - following v1 pattern."""
from pytorch_forecasting.data._tslib_data_module import TslibDataModule
from pytorch_forecasting.tests._data_scenarios import (
data_with_covariates_v2,
make_datasets_v2,
)

data_with_covariates = data_with_covariates_v2()

data_loader_default_kwargs = dict(
target="target",
group_ids=["agency_encoded", "sku_encoded"],
add_relative_time_idx=True,
)

data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
data_loader_default_kwargs.update(data_loader_kwargs)

datasets_info = make_datasets_v2(
data_with_covariates, **data_loader_default_kwargs
)

training_dataset = datasets_info["training_dataset"]
validation_dataset = datasets_info["validation_dataset"]

context_length = data_loader_kwargs.get("context_length", 12)
prediction_length = data_loader_kwargs.get("prediction_length", 4)
batch_size = data_loader_kwargs.get("batch_size", 2)

train_datamodule = TslibDataModule(
time_series_dataset=training_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=batch_size,
train_val_test_split=(0.8, 0.2, 0.0),
)

val_datamodule = TslibDataModule(
time_series_dataset=validation_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=batch_size,
train_val_test_split=(0.0, 1.0, 0.0),
)

test_datamodule = TslibDataModule(
time_series_dataset=validation_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=1,
train_val_test_split=(0.0, 0.0, 1.0),
)

train_datamodule.setup("fit")
val_datamodule.setup("fit")
test_datamodule.setup("test")

train_dataloader = train_datamodule.train_dataloader()
val_dataloader = val_datamodule.val_dataloader()
test_dataloader = test_datamodule.test_dataloader()

return {
"train": train_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"data_module": train_datamodule,
}

@classmethod
def get_test_train_params(cls):
Contributor (review comment): Can you add a few more test cases here?

"""Get test train params."""
from pytorch_forecasting.metrics import QuantileLoss

return [
{},
dict(d_model=16, n_heads=2, e_layers=2, d_ff=64),
dict(
d_model=32,
n_heads=4,
e_layers=3,
d_ff=128,
dropout=0.1,
data_loader_kwargs=dict(
batch_size=4, context_length=8, prediction_length=4
),
),
dict(
hidden_size=32,
n_heads=2,
e_layers=1,
d_ff=64,
factor=2,
activation="relu",
dropout=0.05,
data_loader_kwargs=dict(
context_length=16,
prediction_length=4,
),
loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
),
]
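
A rough illustration (not from this PR) of walking the parameter grid above; exactly how the shared test suite splits these dicts is outside this diff, so here the data_loader_kwargs entry is just separated from the model-side keyword arguments for inspection.

from pytorch_forecasting.models.itransformer import iTransformer_pkg_v2

for params in iTransformer_pkg_v2.get_test_train_params():
    params = dict(params)                               # copy so the stored dict stays untouched
    data_loader_kwargs = params.pop("data_loader_kwargs", {})
    print(sorted(params), data_loader_kwargs)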