diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 34aa145e7..2b6c7274d 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -157,6 +157,8 @@ def __init__( self.categorical_indices = [] self.continuous_indices = [] self._metadata = None + self._target_normalizer_fitted = False + self._feature_scalers_fitted = False for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): if self.time_series_metadata["col_type"].get(col) == "C": @@ -328,6 +330,11 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: # TODO: add scalers, target normalizers etc. + # target is always made into 2D tensor before normalizing. + # helps in generalizing to all cases - single and multi target. + if target.ndim == 1: + target = target.unsqueeze(-1) + categorical = ( features[:, self.categorical_indices] if self.categorical_indices @@ -339,9 +346,54 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: else torch.zeros((features.shape[0], 0)) ) + target_original = target.clone() + + if self._target_normalizer is not None and self._target_normalizer_fitted: + normalized_target = target.clone() + if isinstance(self._target_normalizer, list): + for i, normalizer in enumerate(self._target_normalizer): + normalized_target[:, i] = normalizer.transform(target[:, i]) + elif isinstance(self._target_normalizer, TorchNormalizer): + # single target with n_targets = 1 as the second dimension. + target = target.squeeze(-1) + normalized_target = self._target_normalizer.transform(target).unsqueeze( + -1 + ) # noqa: E501 + elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)): + target_np = target.detach().numpy() + target_np = self._target_normalizer.transform(target_np) + normalized_target = torch.tensor(target_np, dtype=torch.float32) + target = normalized_target + + # applying feature scalers. 
+        if self._feature_scalers_fitted and self.continuous_indices:
+            normalized_cont = continuous.clone()
+            feature_names = [
+                self.time_series_metadata["cols"]["x"][idx]
+                for idx in self.continuous_indices
+            ]
+
+            for feat_idx, feat_name in enumerate(feature_names):
+                if feat_name in self._scalers:
+                    scaler = self._scalers[feat_name]
+                    feature_data = continuous[:, feat_idx]
+
+                    if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
+                        normalized_cont[:, feat_idx] = scaler.transform(feature_data)
+                    elif isinstance(scaler, (StandardScaler, RobustScaler)):
+                        feature_np = feature_data.numpy()
+                        feature_np = scaler.transform(
+                            feature_np.reshape(-1, 1)
+                        ).reshape(-1)  # noqa: E501
+                        normalized_cont[:, feat_idx] = torch.tensor(
+                            feature_np, dtype=torch.float32
+                        )  # noqa: E501
+            continuous = normalized_cont
+
         return {
             "features": {"categorical": categorical, "continuous": continuous},
             "target": target,
+            "target_original": target_original,
             "static": sample.get("st", None),
             "group": sample.get("group", torch.tensor([0])),
             "length": len(target),
@@ -350,13 +402,105 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
             "cutoff_time": cutoff_time,
         }
 
+    def _fit_target_normalizer(self, train_indices):
+        """Fit the target normalizer on the training data."""
+
+        if self._target_normalizer is None:
+            return
+
+        all_targets = []
+        for idx in train_indices:
+            sample = self.time_series_dataset[idx]
+            target = sample["y"]
+            if isinstance(target, torch.Tensor):
+                all_targets.append(target)
+            else:
+                all_targets.append(torch.tensor(target, dtype=torch.float32))
+
+        if not all_targets:
+            return
+
+        all_targets = torch.cat(all_targets, dim=0)
+
+        if isinstance(self._target_normalizer, TorchNormalizer):
+            # handle multiple targets (in case).
+            if all_targets.ndim > 1 and all_targets.shape[1] > 1:
+                self._target_normalizer = [
+                    TorchNormalizer() for _ in range(all_targets.shape[1])
+                ]
+                for i, normalizer in enumerate(self._target_normalizer):
+                    normalizer.fit(all_targets[:, i])
+            else:
+                if all_targets.ndim > 1 and all_targets.shape[1] == 1:
+                    all_targets = all_targets.squeeze(-1)
+                self._target_normalizer.fit(all_targets)
+        elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
+            all_targets_np = all_targets.detach().numpy()
+            if all_targets_np.ndim == 1:
+                all_targets_np = all_targets_np.reshape(-1, 1)
+            self._target_normalizer.fit(all_targets_np)
+
+        self._target_normalizer_fitted = True
+
+    def _fit_scalers(self, train_indices):
+        """Fit scalers on continuous features in the training data."""
+
+        if not self._scalers or not self.continuous_indices:
+            return
+
+        features_to_scale = {
+            self.time_series_metadata["cols"]["x"][idx]: pos
+            for pos, idx in enumerate(self.continuous_indices)
+        }
+
+        for feat_name, scaler in self._scalers.items():
+            if feat_name not in features_to_scale:
+                continue
+            feat_idx = features_to_scale[feat_name]
+            feat_data = []
+
+            for idx in train_indices:
+                sample = self.time_series_dataset[idx]
+                feature_data = sample["x"][:, feat_idx]
+
+                if not isinstance(feature_data, torch.Tensor):
+                    feature_data = torch.tensor(feature_data, dtype=torch.float32)
+
+                feat_data.append(feature_data)
+            feat_data = torch.cat(feat_data, dim=0)
+
+            if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
+                scaler.fit(feat_data)
+            elif isinstance(scaler, (StandardScaler, RobustScaler)):
+                feat_data_np = feat_data.detach().numpy()
+                scaler.fit(feat_data_np.reshape(-1, 1))
+        self._feature_scalers_fitted = True
+
+    def _preprocess_all_data(self, indices: torch.Tensor) -> dict[int, dict[str, Any]]:
+        """Preprocess all data samples for given indices.
+
+        Parameters
+        ----------
+        indices : torch.Tensor
+            Tensor of indices specifying which samples to preprocess.
+
+        Returns
+        -------
+        dict[int, dict[str, Any]]
+            A dictionary mapping series indices to dictionaries containing
+            preprocessed data for each sample.
+        """
+        preprocessed_data = {}
+        for idx in indices:
+            series_idx = idx.item()
+            preprocessed_data[series_idx] = self._preprocess_data(series_idx)
+        return preprocessed_data
+
     class _ProcessedEncoderDecoderDataset(Dataset):
         """PyTorch Dataset for processed encoder-decoder time series data.
 
         Parameters
         ----------
-        dataset : TimeSeries
-            The base time series dataset that provides access to raw data and metadata.
         data_module : EncoderDecoderTimeSeriesDataModule
             The data module handling preprocessing and metadata configuration.
         windows : List[Tuple[int, int, int, int]]
@@ -364,18 +508,20 @@ class _ProcessedEncoderDecoderDataset(Dataset):
             (series_idx, start_idx, enc_length, pred_length).
         add_relative_time_idx : bool, default=False
             Whether to include relative time indices.
+        preprocessed_data : dict[int, dict[str, Any]]
+            Preprocessed data for all series indices of the input dataset.
         """
 
         def __init__(
             self,
-            dataset: TimeSeries,
             data_module: "EncoderDecoderTimeSeriesDataModule",
             windows: list[tuple[int, int, int, int]],
+            preprocessed_data: dict[int, dict[str, Any]],
             add_relative_time_idx: bool = False,
         ):
-            self.dataset = dataset
             self.data_module = data_module
             self.windows = windows
+            self.preprocessed_data = preprocessed_data
             self.add_relative_time_idx = add_relative_time_idx
 
         def __len__(self):
@@ -437,14 +583,18 @@ def __getitem__(self, idx):
            is returned. Otherwise, a tensor of shape (pred_length,)
            is returned.
""" series_idx, start_idx, enc_length, pred_length = self.windows[idx] - data = self.data_module._preprocess_data(series_idx) + data = self.preprocessed_data[series_idx] end_idx = start_idx + enc_length + pred_length encoder_indices = slice(start_idx, start_idx + enc_length) decoder_indices = slice(start_idx + enc_length, end_idx) target_past = data["target"][encoder_indices] - target_scale = target_past[~torch.isnan(target_past)].abs().mean() + + target_original_past = data["target_original"][encoder_indices] + target_scale = ( + target_original_past[~torch.isnan(target_original_past)].abs().mean() + ) # noqa: E501 if torch.isnan(target_scale) or target_scale == 0: target_scale = torch.tensor(1.0) @@ -561,11 +711,10 @@ def __getitem__(self, idx): y = data["target"][decoder_indices] - if self.data_module.n_targets > 1: - y = [t.squeeze(-1) for t in torch.split(y, 1, dim=1)] + if y.shape[-1] > 1: + y = [y[:, i] for i in range(y.shape[-1])] else: y = y.squeeze(-1) - return x, y def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, int]]: @@ -648,39 +797,48 @@ def setup(self, stage: Optional[str] = None): self._test_indices = self._split_indices[self._train_size + self._val_size :] if stage is None or stage == "fit": + self._fit_target_normalizer(self._train_indices) + self._fit_scalers(self._train_indices) if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self._train_preprocessed = self._preprocess_all_data( + self._train_indices + ) + self._val_preprocessed = self._preprocess_all_data(self._val_indices) + self.train_windows = self._create_windows(self._train_indices) self.val_windows = self._create_windows(self._val_indices) self.train_dataset = self._ProcessedEncoderDecoderDataset( - self.time_series_dataset, self, self.train_windows, + self._train_preprocessed, self.add_relative_time_idx, ) self.val_dataset = self._ProcessedEncoderDecoderDataset( - self.time_series_dataset, self, self.val_windows, + self._val_preprocessed, self.add_relative_time_idx, ) elif stage == "test": if not hasattr(self, "test_dataset"): + self._test_preprocessed = self._preprocess_all_data(self._test_indices) self.test_windows = self._create_windows(self._test_indices) self.test_dataset = self._ProcessedEncoderDecoderDataset( - self.time_series_dataset, self, self.test_windows, + self._test_preprocessed, self.add_relative_time_idx, ) elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) + self._predict_preprocessed = self._preprocess_all_data(predict_indices) self.predict_windows = self._create_windows(predict_indices) self.predict_dataset = self._ProcessedEncoderDecoderDataset( - self.time_series_dataset, self, self.predict_windows, + self._predict_preprocessed, self.add_relative_time_idx, ) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 77c84c371..45a93e57c 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -1,8 +1,11 @@ import numpy as np import pandas as pd import pytest +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.encoders import TorchNormalizer from pytorch_forecasting.data.timeseries import TimeSeries @@ -464,9 +467,113 @@ def test_multivariate_target(): max_encoder_length=10, max_prediction_length=5, batch_size=4, + target_normalizer=TorchNormalizer(), ) dm.setup() x, y = dm.train_dataset[0] 
assert len(y) == 2 + + +@pytest.mark.parametrize( + "normalizer", + [ + "auto", + TorchNormalizer(), + StandardScaler(), + RobustScaler(), + ], +) +def test_target_normalizers(sample_timeseries_data, normalizer): + """Test different target normalizers. + + Ensures compatibility and correct integration of various normalizers. + Verifies that: + - The normalizer is applied correctly. + - Output shapes are as expected. + - Target is actually scaled. + """ + dm_no_norm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + target_normalizer=None, + ) + dm_no_norm.setup(stage="fit") + + dm_with_norm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + target_normalizer=normalizer, + ) + dm_with_norm.setup(stage="fit") + + x_no_norm, y_no_norm = dm_no_norm.train_dataset[0] + x_with_norm, y_with_norm = dm_with_norm.train_dataset[0] + assert y_with_norm.shape == y_no_norm.shape + assert x_with_norm["target_past"].shape == x_no_norm["target_past"].shape + + assert not torch.allclose(y_with_norm, y_no_norm), "Target should be normalized" + assert not torch.allclose( + x_with_norm["target_past"], x_no_norm["target_past"] + ), "target_past should be normalized" + + assert y_with_norm.var() < y_no_norm.var(), "Normalization should reduce variance" + + +@pytest.mark.parametrize( + "scaler_type", + [ + TorchNormalizer, + StandardScaler, + RobustScaler, + ], +) +def test_feature_scaling(sample_timeseries_data, scaler_type): + """Test feature scaling with different scalers. + + Verifies that: + - Scaling is actually applied (data changes) + - Only specified features are scaled + - Output format is preserved + """ + scalers = { + "cont_feat1": scaler_type(), + "cont_feat2": scaler_type(), + } + + dm_no_scale = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + scalers=None, + ) + dm_no_scale.setup(stage="fit") + + dm_with_scale = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + scalers=scalers, + ) + dm_with_scale.setup(stage="fit") + + x_no_scale, _ = dm_no_scale.train_dataset[0] + x_with_scale, _ = dm_with_scale.train_dataset[0] + + assert x_with_scale["encoder_cont"].shape == x_no_scale["encoder_cont"].shape + assert x_with_scale["decoder_cont"].shape == x_no_scale["decoder_cont"].shape + + assert not torch.allclose( + x_with_scale["encoder_cont"], x_no_scale["encoder_cont"] + ), "Continuous features should be scaled" + + assert ( + x_with_scale["encoder_cont"].var() < x_no_scale["encoder_cont"].var() + ), "Scaling should reduce variance"
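
A minimal usage sketch of the new arguments, assuming `dataset` is a pre-built
`TimeSeries` instance whose continuous columns are named `cont_feat1` and
`cont_feat2` (names taken from the test fixture above); statistics are fitted on
the training split during `setup("fit")` and reused for validation, test, and
predict data:

from sklearn.preprocessing import StandardScaler

from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule
from pytorch_forecasting.data.encoders import TorchNormalizer

# `dataset` is an assumed, pre-built TimeSeries instance, analogous to the
# `sample_timeseries_data` fixture used in the tests above.
dm = EncoderDecoderTimeSeriesDataModule(
    time_series_dataset=dataset,
    max_encoder_length=24,
    max_prediction_length=12,
    batch_size=4,
    target_normalizer=TorchNormalizer(),       # fitted on train indices only
    scalers={"cont_feat1": StandardScaler()},  # per-feature scaler, train-fitted
)
dm.setup(stage="fit")

x, y = dm.train_dataset[0]
# y and x["target_past"] are normalized; x["encoder_cont"] holds the scaled
# continuous features, while target_scale is derived from the raw target.
print(y.shape, x["encoder_cont"].shape)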
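
The fitting helpers (`_fit_target_normalizer`, `_fit_scalers`) follow the usual
leakage-avoidance contract: statistics are computed from the training indices
only, and the fitted state then transforms every split. A standalone sketch of
that contract with `TorchNormalizer` (standard scaling by default):

import torch

from pytorch_forecasting.data.encoders import TorchNormalizer

normalizer = TorchNormalizer()

train_y = torch.tensor([10.0, 12.0, 11.0, 13.0])
normalizer.fit(train_y)  # center and scale are computed from training data only

val_y = torch.tensor([14.0, 15.0])
val_y_norm = normalizer.transform(val_y)              # reuses train statistics
val_y_raw = normalizer.inverse_transform(val_y_norm)  # recovers original scale
print(val_y_norm, val_y_raw)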
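
On `target_original`: `target_scale` should reflect the true magnitude of the
series even when `target` itself has been normalized, which is why `__getitem__`
now derives it from the unnormalized copy. Reduced to its essentials:

import torch

# Mean absolute value of the unnormalized encoder window, NaNs excluded,
# mirroring _ProcessedEncoderDecoderDataset.__getitem__ above.
target_original_past = torch.tensor([100.0, 110.0, float("nan"), 90.0])
target_scale = target_original_past[~torch.isnan(target_original_past)].abs().mean()
if torch.isnan(target_scale) or target_scale == 0:
    target_scale = torch.tensor(1.0)
print(target_scale)  # tensor(100.)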