Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 125 additions & 16 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Base model for all PVNet submodels"""

import copy
import logging
import os
import shutil
import time
from importlib.metadata import version
from pathlib import Path
from typing import Optional

import hydra
import torch
import torch.nn.functional as F
import yaml
from huggingface_hub import ModelCard, ModelCardData, snapshot_download
from huggingface_hub.hf_api import HfApi
Expand All @@ -26,7 +29,9 @@
)


def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHOLDER") -> dict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

def fill_config_paths_with_placeholder(
config: dict, placeholder: str = "PLACEHOLDER"
) -> dict:
"""Modify the config in place to fill data paths with placeholder strings.

Args:
Expand Down Expand Up @@ -75,14 +80,16 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:
del input_config["nwp"][nwp_source]
else:
# Replace the image size
nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

nwp_pixel_size = model.nwp_encoders_dict[
nwp_source
].image_size_pixels
nwp_config["image_size_pixels_height"] = nwp_pixel_size
nwp_config["image_size_pixels_width"] = nwp_pixel_size

# Replace the interval_end_minutes minutes
nwp_config["interval_end_minutes"] = (
nwp_config["interval_start_minutes"] +
(model.nwp_encoders_dict[nwp_source].sequence_length - 1)
nwp_config["interval_start_minutes"]
+ (model.nwp_encoders_dict[nwp_source].sequence_length - 1)
* nwp_config["time_resolution_minutes"]
)

Expand All @@ -99,8 +106,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict:

# Replace the interval_end_minutes minutes
sat_config["interval_end_minutes"] = (
sat_config["interval_start_minutes"] +
(model.sat_encoder.sequence_length - 1)
sat_config["interval_start_minutes"]
+ (model.sat_encoder.sequence_length - 1)
* sat_config["time_resolution_minutes"]
)

Expand Down Expand Up @@ -158,7 +165,7 @@ def download_from_hf(
return [f"{save_dir}/{f}" for f in filename]
else:
return f"{save_dir}/{filename}"

except Exception as e:
if attempt == max_retries:
raise Exception(
Expand Down Expand Up @@ -290,12 +297,14 @@ def save_pretrained(
# Save the model config and data config
if isinstance(model_config, dict):
with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile:
yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False)
yaml.dump(
model_config, outfile, sort_keys=False, default_flow_style=False
)

# Save cleaned version of input data configuration file
with open(data_config_path) as cfg:
config = yaml.load(cfg, Loader=yaml.FullLoader)

config = fill_config_paths_with_placeholder(config)
config = minimize_config_for_model(config, self)

Expand All @@ -304,13 +313,17 @@ def save_pretrained(

# Save the datamodule config
if datamodule_config_path is not None:
shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit


shutil.copyfile(
datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME
)

# Save the full experimental config
if experiment_config_path is not None:
shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)

card = self.create_hugging_face_model_card(card_template_path, wandb_repo, wandb_ids)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

card = self.create_hugging_face_model_card(
card_template_path, wandb_repo, wandb_ids
)

(save_directory / MODEL_CARD_NAME).write_text(str(card))

Expand Down Expand Up @@ -370,8 +383,9 @@ def create_hugging_face_model_card(

# Find package versions for OCF packages
packages_to_display = ["pvnet", "ocf-data-sampler"]
packages_and_versions = {package: version(package) for package in packages_to_display}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit


packages_and_versions = {
package: version(package) for package in packages_to_display
}

package_versions_markdown = ""
for package, v in packages_and_versions.items():
Expand All @@ -392,7 +406,8 @@ def __init__(
self,
history_minutes: int,
forecast_minutes: int,
output_quantiles: list[float] | None = None,
output_quantiles: Optional[list[float]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change the type hint back here and match in the line below?

`param: type | None = None` is the current best practice for Python >= 3.10, rather than `param: Optional[type] = None`, which was the practice in Python < 3.10.

num_gmm_components: Optional[int] = None,
target_key: str = "gsp",
interval_minutes: int = 30,
):
Expand All @@ -403,6 +418,8 @@ def __init__(
forecast_minutes (int): Length of the GSP forecast period in minutes
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
num_gmm_components: Number of Gaussian Mixture Model components to use for the model.
If None, output quantiles must be set. If both None, the output is a single value.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If None, output quantiles must be set. If both None, the output is a single value.
If None, output quantiles must be set. If both None, the output is a single value.

target_key: The key of the target variable in the batch
interval_minutes: The interval in minutes between each timestep in the data
"""
Expand All @@ -413,6 +430,7 @@ def __init__(
self.history_minutes = history_minutes
self.forecast_minutes = forecast_minutes
self.output_quantiles = output_quantiles
self.num_gmm_components = num_gmm_components
self.interval_minutes = interval_minutes

# Number of timesteps for 30 minutely data
Expand All @@ -422,12 +440,24 @@ def __init__(
# Store whether the model should use quantile regression or simply predict the mean
self.use_quantile_regression = self.output_quantiles is not None

self.use_gmm = self.num_gmm_components is not None

# Both quantile regression and GMM cannot be used at the same time
if self.use_quantile_regression and self.use_gmm:
raise ValueError(
"Cannot use quantile regression and GMM at the same time. "
"Please set either output_quantiles or num_gmm_components to None."
)

# Store the number of output features that the model should predict for
if self.use_quantile_regression:
self.num_output_features = self.forecast_len * len(self.output_quantiles)
elif self.use_gmm:
self.num_output_features = self.forecast_len * self.num_gmm_components * 3
else:
self.num_output_features = self.forecast_len


def _adapt_batch(self, batch: TensorBatch) -> TensorBatch:
"""Slice batches into appropriate shapes for model.

Expand Down Expand Up @@ -491,6 +521,33 @@ def _adapt_batch(self, batch: TensorBatch) -> TensorBatch:

return new_batch

def _parse_gmm_params(
    self,
    y_gmm: torch.Tensor,
    min_sigma: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the flat network output into Gaussian-mixture parameters (μ, σ, π).

    The last axis of the reshaped output holds, in order, the raw mean, the raw
    (unconstrained) scale and the raw mixture logit of each component.

    Args:
        y_gmm: Network output of shape (batch, forecast_len * num_gmm_components * 3)
        min_sigma: Constant added to the softplus-transformed scale so that σ can never
            collapse to zero. Defaults to 1e-3.

    Returns:
        A tuple (mus, sigmas, pis), each of shape (batch, forecast_len, num_gmm_components):
            mus: The component means, taken from the network output unchanged
            sigmas: The component standard deviations, strictly greater than zero
            pis: The mixture weights, which sum to 1 over the component axis
    """
    # View as [batch, forecast_len, num_components, 3]
    params = y_gmm.view(
        y_gmm.shape[0],
        self.forecast_len,
        self.num_gmm_components,
        3,
    )

    mus = params[..., 0]

    # Softplus maps the raw scale to a non-negative value; the floor keeps σ away from 0
    sigmas = F.softplus(params[..., 1]) + min_sigma

    # Softmax over the component axis turns the raw logits into mixture weights
    pis = F.softmax(params[..., 2], dim=-1)

    return mus, sigmas, pis

def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
"""
Convert network prediction into a point prediction.
Expand All @@ -508,4 +565,56 @@ def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
"""
# y_quantiles Shape: batch_size, seq_length, num_quantiles
idx = self.output_quantiles.index(0.5)
return y_quantiles[..., idx]
return y_quantiles[..., idx]

def _gmm_to_prediction(self, y_gmm: torch.Tensor) -> torch.Tensor:
    """Convert the network's GMM output into a point prediction.

    The point forecast is the **mixture mean** E[Y] = Σ_k π_k μ_k.

    Note:
        This is the mean of the Gaussian mixture, not the median, whereas quantile/MAE
        training targets the median. If a median point forecast is desired, it could be
        obtained by solving F(x) = Σ_k π_k Φ((x - μ_k)/σ_k) = 0.5 per horizon step
        (e.g. by bisection), or approximated via sampling.

    Args:
        y_gmm: Flat network output, in the layout expected by `_parse_gmm_params`

    Returns:
        The mixture mean, of shape (batch, forecast_len)
    """
    # The component scales do not enter the mixture mean, so they are discarded
    mus, _, pis = self._parse_gmm_params(y_gmm)

    # Expectation over the mixture: weight each component mean and sum over components
    return (pis * mus).sum(dim=-1)

def _sample_from_gmm(
    self,
    mus: torch.Tensor,
    sigmas: torch.Tensor,
    pis: torch.Tensor,
    n_samples: int = 20,
) -> torch.Tensor:
    """Sample from the Gaussian Mixture Model for each timestep and batch element.

    Each draw first picks a mixture component according to the weights `pis`, then
    samples from that component's normal distribution. The component axis must be the
    last axis; any number of leading axes is accepted.

    Args:
        mus: Component means, e.g. [batch, horizon, components]
        sigmas: Component standard deviations, same shape as `mus`
        pis: Mixture weights over the last axis, same shape as `mus`
        n_samples: Number of samples to draw

    Returns:
        Samples with the component axis removed and a leading sample axis added,
        e.g. [n_samples, batch, horizon]
    """
    # Sample component indices according to the mixture weights
    component_indices = torch.distributions.Categorical(pis).sample((n_samples,))
    idx = component_indices.unsqueeze(-1)  # [n_samples, ..., 1]

    # Broadcast the parameters along a new leading sample axis so that they can be
    # indexed by the sampled component
    mus_tiled = mus.expand(n_samples, *mus.shape)
    sigmas_tiled = sigmas.expand(n_samples, *sigmas.shape)

    # Gather the μ and σ of the chosen component for every draw
    chosen_mus = torch.gather(mus_tiled, dim=-1, index=idx).squeeze(-1)
    chosen_sigmas = torch.gather(sigmas_tiled, dim=-1, index=idx).squeeze(-1)

    # Sample from the selected normal distributions
    return torch.distributions.Normal(chosen_mus, chosen_sigmas).sample()
35 changes: 23 additions & 12 deletions pvnet/models/late_fusion/late_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self,
output_network: AbstractLinearNetwork,
output_quantiles: list[float] | None = None,
num_gmm_components: int | None = None,
nwp_encoders_dict: dict[str, AbstractNWPSatelliteEncoder] | None = None,
sat_encoder: AbstractNWPSatelliteEncoder | None = None,
pv_encoder: AbstractSitesEncoder | None = None,
Expand Down Expand Up @@ -77,6 +78,9 @@ def __init__(
features to produce the forecast.
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
num_gmm_components: If set to an integer, the model will predict parameters for a
Gaussian mixture model with this many components. Mutually exclusive with
output_quantiles.
nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to
encode the NWP data from 4D into a 1D feature vector from different sources.
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite
Expand Down Expand Up @@ -118,6 +122,7 @@ def __init__(
history_minutes=history_minutes,
forecast_minutes=forecast_minutes,
output_quantiles=output_quantiles,
num_gmm_components=num_gmm_components,
target_key=target_key,
interval_minutes=interval_minutes,
)
Expand Down Expand Up @@ -171,7 +176,7 @@ def __init__(
)
if add_image_embedding_channel:
self.sat_embed = ImageEmbedding(
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels,
)

# Update num features
Expand Down Expand Up @@ -230,15 +235,16 @@ def __init__(
fusion_input_features += self.pv_encoder.out_features

if self.use_id_embedding:
self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
self.embed = nn.Embedding(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

num_embeddings=num_embeddings, embedding_dim=embedding_dim
)

# Update num features
fusion_input_features += embedding_dim

if self.include_sun:
self.sun_fc1 = nn.Linear(
in_features=2
* (self.forecast_len + self.history_len + 1),
in_features=2 * (self.forecast_len + self.history_len + 1),
out_features=16,
)

Expand All @@ -247,8 +253,7 @@ def __init__(

if self.include_time:
self.time_fc1 = nn.Linear(
in_features=4
* (self.forecast_len + self.history_len + 1),
in_features=4 * (self.forecast_len + self.history_len + 1),
out_features=32,
)

Expand All @@ -268,7 +273,6 @@ def __init__(
out_features=self.num_output_features,
)


def forward(self, x: TensorBatch) -> torch.Tensor:
"""Run model forward"""

Expand All @@ -278,7 +282,10 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
if self.use_id_embedding:
# eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
id = torch.tensor(
[self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

[
self.location_id_mapping[i.item()]
for i in x[f"{self._target_key}_id"]
],
device=x[f"{self._target_key}_id"].device,
dtype=torch.int64,
)
Expand All @@ -288,7 +295,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
if self.include_sat:
# Shape: batch_size, seq_length, channel, height, width
sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unsplit

sat_data = torch.swapaxes(
sat_data, 1, 2
).float() # switch time and channels

if self.add_image_embedding_channel:
sat_data = self.sat_embed(sat_data, id)
Expand Down Expand Up @@ -343,7 +352,7 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

if self.include_time:
time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]]
time = torch.cat(time, dim=1).float()
Expand All @@ -353,7 +362,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
out = self.output_network(modes)

if self.use_quantile_regression:
# Shape: batch_size, seq_length * num_quantiles
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
out = out.view(out.size(0), self.forecast_len, len(self.output_quantiles))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solid refinement!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:)


# no further reshape needed if gmm is used: BaseModel._parse_gmm_params will view it as
# (batch, forecast_len, num_components, 3)

return out
Loading
Loading