26 commits

49c3ab4  add feature scaling to d2 (PranavBhatP, Oct 12, 2025)
1942383  Merge branch 'main' into feature-scaling (PranavBhatP, Oct 12, 2025)
403e0f2  fix incorrect orig_idx index (PranavBhatP, Oct 12, 2025)
5661be4  fix incorrect attribute (PranavBhatP, Oct 12, 2025)
38cefe4  handle unfitted scalers (PranavBhatP, Oct 16, 2025)
f242290  change accelerator to cpu in v2 notebook cell 10 (PranavBhatP, Oct 17, 2025)
54da1c4  use torch.from_numpy instead of torch.tensor for numpy to torch conve… (PranavBhatP, Oct 18, 2025)
c145e9b  revert accelerator mode to auto from cpu for example notebook trainin… (PranavBhatP, Oct 18, 2025)
ec4cf03  potential fix for issue in training of v2 (PranavBhatP, Oct 21, 2025)
18f2b2a  replace MAE() with nn.L1Loss() to fix notebook test failures (PranavBhatP, Nov 2, 2025)
5c99959  Merge branch 'main' into feature-scaling (PranavBhatP, Nov 2, 2025)
85ba7cb  Merge branch 'main' into feature-scaling (PranavBhatP, Nov 4, 2025)
d96aed5  Merge branch 'main' into feature-scaling (PranavBhatP, Nov 25, 2025)
fd8411a  revert notebook state (PranavBhatP, Dec 5, 2025)
0830090  Merge branch 'main' into feature-scaling (PranavBhatP, Dec 5, 2025)
ff42a1b  some changes to data module - incomplete (PranavBhatP, Dec 7, 2025)
f86f9a5  fix scaling and target norm - working (PranavBhatP, Dec 8, 2025)
728cfad  remove target_scale and add target_normalizer instead (PranavBhatP, Dec 8, 2025)
091d0f8  restore original notebook (PranavBhatP, Dec 8, 2025)
6d38331  revert breaking change on target scale (PranavBhatP, Dec 8, 2025)
4ff3444  Merge branch 'main' into feature-scaling (PranavBhatP, Dec 15, 2025)
ca5cb97  separate concerns for feature scaling and target normalizers inside _… (PranavBhatP, Dec 15, 2025)
77dc43f  fix multi target handling during normalization (PranavBhatP, Dec 15, 2025)
757fe83  fix data module output format (PranavBhatP, Dec 15, 2025)
bceb0e3  add tests for feature scaling and norm (PranavBhatP, Dec 15, 2025)
c07a343  remove unnecessary dataset param from internal D2 dataset class (PranavBhatP, Dec 15, 2025)
184 changes: 171 additions & 13 deletions pytorch_forecasting/data/data_module.py
@@ -157,6 +157,8 @@ def __init__(
        self.categorical_indices = []
        self.continuous_indices = []
        self._metadata = None
        self._target_normalizer_fitted = False
        self._feature_scalers_fitted = False

        for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
            if self.time_series_metadata["col_type"].get(col) == "C":
@@ -328,6 +330,11 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:

        # TODO: add scalers, target normalizers etc.

        # target is always made into 2D tensor before normalizing.
        # helps in generalizing to all cases - single and multi target.
        if target.ndim == 1:
            target = target.unsqueeze(-1)

        categorical = (
            features[:, self.categorical_indices]
            if self.categorical_indices
@@ -339,9 +346,54 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
            else torch.zeros((features.shape[0], 0))
        )

        target_original = target.clone()

        if self._target_normalizer is not None and self._target_normalizer_fitted:
            normalized_target = target.clone()
            if isinstance(self._target_normalizer, list):
                for i, normalizer in enumerate(self._target_normalizer):
                    normalized_target[:, i] = normalizer.transform(target[:, i])

Member:
Here, if the list mixes TorchNormalizer and StandardScaler, would it not throw an error, since we are not detaching the tensor?

PranavBhatP (Contributor Author), Dec 15, 2025:
The list is guaranteed to contain TorchNormalizer instances: we only need a list of normalizers when TorchNormalizer is specified as the target normalizer and there are multiple targets, and the list is set during fitting, as below. StandardScaler works as-is: it takes an additional dimension in the tensor for n_targets, converts it to numpy format, and fits on each target column automatically.

def _fit_target_normalizer(self, train_indices):
    ....
    if isinstance(self._target_normalizer, TorchNormalizer):
        # handle multiple targets (in case).
        if all_targets.ndim > 1 and all_targets.shape[1] > 1:
            self._target_normalizer = [
                TorchNormalizer() for _ in range(all_targets.shape[1])
            ]
            for i, normalizer in enumerate(self._target_normalizer):
                normalizer.fit(all_targets[:, i])
        else:
            if all_targets.ndim > 1 and all_targets.shape[1] == 1:
                all_targets = all_targets.squeeze(-1)
            self._target_normalizer.fit(all_targets)
    elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
        all_targets_np = all_targets.detach().numpy()
        if all_targets_np.ndim == 1:
            all_targets_np = all_targets_np.reshape(-1, 1)
        self._target_normalizer.fit(all_targets_np)

    self._target_normalizer_fitted = True

Member:

Thanks, it is a bit clearer now. But for multi-target we had a "list" of tensors, right? Are we creating an n+1-dimensional tensor somewhere for the sklearn scalers? If not, this may lead to failure.
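
A minimal sketch of the shape concern (an illustration, not code from this PR; the data and names are assumed): sklearn scalers reject 1D input and treat the second dimension as the feature axis, so a list of per-target tensors has to be stacked into an (n_samples, n_targets) array before fit.

import torch
from sklearn.preprocessing import StandardScaler

# Hypothetical multi-target data: one 1D tensor per target.
targets = [torch.randn(100), torch.randn(100) * 5 + 2]

# Fitting a single 1D target directly raises sklearn's "Expected 2D array" error.
try:
    StandardScaler().fit(targets[0].numpy())
except ValueError as err:
    print(err)

# Stacking along a new trailing dimension gives the (n_samples, n_targets)
# layout asked about above: shape (100, 2), one column per target.
stacked = torch.stack(targets, dim=-1).numpy()
scaler = StandardScaler().fit(stacked)  # fits each target column independently
print(scaler.mean_.shape)  # (2,)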

Member:
I think we should also start testing multi-target options now, so that we can find any failures in that case.
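
Along those lines, a minimal sketch of such a multi-target test (an illustration under assumed shapes, not code from this PR): it mirrors the per-target normalizer list that _fit_target_normalizer builds, and elides the data-module and dataloader wiring.

import torch
from pytorch_forecasting.data.encoders import TorchNormalizer

def test_multi_target_normalizer_list():
    torch.manual_seed(0)
    # Assumed stacked layout: (n_samples, n_targets), with per-target scales.
    all_targets = torch.randn(200, 2) * torch.tensor([1.0, 5.0]) + torch.tensor(
        [0.0, 2.0]
    )

    # Mirror the per-target list built in _fit_target_normalizer above.
    normalizers = [TorchNormalizer() for _ in range(all_targets.shape[1])]
    for i, normalizer in enumerate(normalizers):
        normalizer.fit(all_targets[:, i])

    normalized = torch.stack(
        [normalizers[i].transform(all_targets[:, i]) for i in range(2)], dim=1
    )
    # Each target column should be centered and scaled independently.
    assert torch.allclose(normalized.mean(dim=0), torch.zeros(2), atol=1e-4)
    assert torch.allclose(normalized.std(dim=0), torch.ones(2), atol=1e-1)

A fuller version would push a two-target dataset through the data module's setup() and assert on the emitted batches, but that depends on fixtures not shown in this diff.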

            elif isinstance(self._target_normalizer, TorchNormalizer):
                # single target with n_targets = 1 as the second dimension.
                target = target.squeeze(-1)
                normalized_target = self._target_normalizer.transform(target).unsqueeze(
                    -1
                )  # noqa: E501
            elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
                target_np = target.detach().numpy()
                target_np = self._target_normalizer.transform(target_np)
                normalized_target = torch.tensor(target_np, dtype=torch.float32)
            target = normalized_target

        # applying feature scalers.
        if self._feature_scalers_fitted and self.continuous_indices:
            normalized_cont = continuous.clone()
            feature_names = [
                self.time_series_metadata["cols"]["x"][idx]
                for idx in self.continuous_indices
            ]

            for feat_idx, feat_name in enumerate(feature_names):
                if feat_name in self._scalers:
                    scaler = self._scalers[feat_name]
                    feature_data = continuous[:, feat_idx]

                    if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
                        normalized_cont[:, feat_idx] = scaler.transform(feature_data)
                    elif isinstance(scaler, (StandardScaler, RobustScaler)):
                        feature_np = feature_data.numpy()
                        feature_np = scaler.transform(
                            feature_np.reshape(-1, 1)
                        ).reshape(-1)  # noqa: E501
                        normalized_cont[:, feat_idx] = torch.tensor(
                            feature_np, dtype=torch.float32
                        )  # noqa: E501
            continuous = normalized_cont

        return {
            "features": {"categorical": categorical, "continuous": continuous},
            "target": target,
            "target_original": target_original,
            "static": sample.get("st", None),
            "group": sample.get("group", torch.tensor([0])),
            "length": len(target),
@@ -350,32 +402,126 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
            "cutoff_time": cutoff_time,
        }

    def _fit_target_normalizer(self, train_indices):
        """Fit the target normalizer on the training data."""

        if self._target_normalizer is None:
            return

        all_targets = []
        for idx in train_indices:
            sample = self.time_series_dataset[idx]
            target = sample["y"]
            if isinstance(target, torch.Tensor):
                all_targets.append(target)
            else:
                all_targets.append(torch.tensor(target, dtype=torch.float32))

        if not all_targets:
            return

        all_targets = torch.cat(all_targets, dim=0)

        if isinstance(self._target_normalizer, TorchNormalizer):
            # handle multiple targets (in case).
            if all_targets.ndim > 1 and all_targets.shape[1] > 1:
                self._target_normalizer = [
                    TorchNormalizer() for _ in range(all_targets.shape[1])
                ]
                for i, normalizer in enumerate(self._target_normalizer):
                    normalizer.fit(all_targets[:, i])
            else:
                if all_targets.ndim > 1 and all_targets.shape[1] == 1:
                    all_targets = all_targets.squeeze(-1)
                self._target_normalizer.fit(all_targets)
        elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
            all_targets_np = all_targets.detach().numpy()
            if all_targets_np.ndim == 1:
                all_targets_np = all_targets_np.reshape(-1, 1)
            self._target_normalizer.fit(all_targets_np)

        self._target_normalizer_fitted = True

    def _fit_scalers(self, train_indices):
        """Fit scalers on continuous features in the training data."""

        if not self._scalers or not self.continuous_indices:
            return

        features_to_scale = {
            self.time_series_metadata["cols"]["x"][idx]: pos
            for pos, idx in enumerate(self.continuous_indices)
        }

        for feat_name, scaler in self._scalers.items():
            if feat_name not in features_to_scale:
                continue
            feat_idx = features_to_scale[feat_name]
            feat_data = []

            for idx in train_indices:
                sample = self.time_series_dataset[idx]
                feature_data = sample["x"][:, feat_idx]

                if not isinstance(feature_data, torch.Tensor):
                    feature_data = torch.tensor(feature_data, dtype=torch.float32)

                feat_data.append(feature_data)
            feat_data = torch.cat(feat_data, dim=0)

            if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
                scaler.fit(feat_data)
            elif isinstance(scaler, (StandardScaler, RobustScaler)):
                feat_data_np = feat_data.detach().numpy()
                scaler.fit(feat_data_np.reshape(-1, 1))
        self._feature_scalers_fitted = True

    def _preprocess_all_data(self, indices: torch.Tensor) -> dict[int, dict[str, Any]]:
        """Preprocess all data samples for given indices.

        Parameters
        ----------
        indices : torch.Tensor
            Tensor of indices specifying which samples to preprocess.

        Returns
        -------
        dict[int, dict[str, Any]]
            A dictionary mapping series indices to dictionaries containing
            preprocessed data for each sample.
        """
        preprocessed_data = {}
        for idx in indices:
            series_idx = idx.item()
            preprocessed_data[series_idx] = self._preprocess_data(series_idx)
        return preprocessed_data

    class _ProcessedEncoderDecoderDataset(Dataset):
        """PyTorch Dataset for processed encoder-decoder time series data.

        Parameters
        ----------
        dataset : TimeSeries
            The base time series dataset that provides access to raw data and metadata.
        data_module : EncoderDecoderTimeSeriesDataModule
            The data module handling preprocessing and metadata configuration.
        windows : List[Tuple[int, int, int, int]]
            List of window tuples containing
            (series_idx, start_idx, enc_length, pred_length).
        preprocessed_data : dict[int, dict[str, Any]]
            Preprocessed data for all time series indices of the input dataset.
        add_relative_time_idx : bool, default=False
            Whether to include relative time indices.
        """

        def __init__(
            self,
            dataset: TimeSeries,
            data_module: "EncoderDecoderTimeSeriesDataModule",
            windows: list[tuple[int, int, int, int]],
            preprocessed_data: dict[int, dict[str, Any]],
            add_relative_time_idx: bool = False,
        ):
            self.dataset = dataset
            self.data_module = data_module
            self.windows = windows
            self.preprocessed_data = preprocessed_data
            self.add_relative_time_idx = add_relative_time_idx

        def __len__(self):
@@ -437,14 +583,18 @@ def __getitem__(self, idx):
            is returned. Otherwise, a tensor of shape (pred_length,) is returned.
            """
            series_idx, start_idx, enc_length, pred_length = self.windows[idx]
-           data = self.data_module._preprocess_data(series_idx)
+           data = self.preprocessed_data[series_idx]

            end_idx = start_idx + enc_length + pred_length
            encoder_indices = slice(start_idx, start_idx + enc_length)
            decoder_indices = slice(start_idx + enc_length, end_idx)

-           target_past = data["target"][encoder_indices]
-           target_scale = target_past[~torch.isnan(target_past)].abs().mean()
+           target_original_past = data["target_original"][encoder_indices]
+           target_scale = (
+               target_original_past[~torch.isnan(target_original_past)].abs().mean()
+           )  # noqa: E501
            if torch.isnan(target_scale) or target_scale == 0:
                target_scale = torch.tensor(1.0)
@@ -561,11 +711,10 @@ def __getitem__(self, idx):

            y = data["target"][decoder_indices]

-           if self.data_module.n_targets > 1:
-               y = [t.squeeze(-1) for t in torch.split(y, 1, dim=1)]
+           if y.shape[-1] > 1:
+               y = [y[:, i] for i in range(y.shape[-1])]
            else:
                y = y.squeeze(-1)

            return x, y

    def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, int]]:
@@ -648,39 +797,48 @@ def setup(self, stage: Optional[str] = None):
        self._test_indices = self._split_indices[self._train_size + self._val_size :]

        if stage is None or stage == "fit":
            self._fit_target_normalizer(self._train_indices)
            self._fit_scalers(self._train_indices)
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                self._train_preprocessed = self._preprocess_all_data(
                    self._train_indices
                )
                self._val_preprocessed = self._preprocess_all_data(self._val_indices)

                self.train_windows = self._create_windows(self._train_indices)
                self.val_windows = self._create_windows(self._val_indices)

                self.train_dataset = self._ProcessedEncoderDecoderDataset(
                    self.time_series_dataset,
                    self,
                    self.train_windows,
                    self._train_preprocessed,
                    self.add_relative_time_idx,
                )
                self.val_dataset = self._ProcessedEncoderDecoderDataset(
                    self.time_series_dataset,
                    self,
                    self.val_windows,
                    self._val_preprocessed,
                    self.add_relative_time_idx,
                )

        elif stage == "test":
            if not hasattr(self, "test_dataset"):
                self._test_preprocessed = self._preprocess_all_data(self._test_indices)
                self.test_windows = self._create_windows(self._test_indices)
                self.test_dataset = self._ProcessedEncoderDecoderDataset(
                    self.time_series_dataset,
                    self,
                    self.test_windows,
                    self._test_preprocessed,
                    self.add_relative_time_idx,
                )
        elif stage == "predict":
            predict_indices = torch.arange(len(self.time_series_dataset))
            self._predict_preprocessed = self._preprocess_all_data(predict_indices)
            self.predict_windows = self._create_windows(predict_indices)
            self.predict_dataset = self._ProcessedEncoderDecoderDataset(
                self.time_series_dataset,
                self,
                self.predict_windows,
                self._predict_preprocessed,
                self.add_relative_time_idx,
            )