-
-
Notifications
You must be signed in to change notification settings - Fork 38
GSOC: Probabilistic Machine Learning for Solar Forecasting: Applying Gaussian Mixture Models for output #448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f50a665
ea90672
c344676
2a1bc84
c26dae1
d2736c2
d7dabca
93c0b64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,14 +1,17 @@ | ||||||
"""Base model for all PVNet submodels""" | ||||||
|
||||||
import copy | ||||||
import logging | ||||||
import os | ||||||
import shutil | ||||||
import time | ||||||
from importlib.metadata import version | ||||||
from pathlib import Path | ||||||
from typing import Optional | ||||||
|
||||||
import hydra | ||||||
import torch | ||||||
import torch.nn.functional as F | ||||||
import yaml | ||||||
from huggingface_hub import ModelCard, ModelCardData, snapshot_download | ||||||
from huggingface_hub.hf_api import HfApi | ||||||
|
@@ -26,7 +29,9 @@ | |||||
) | ||||||
|
||||||
|
||||||
def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHOLDER") -> dict: | ||||||
def fill_config_paths_with_placeholder( | ||||||
config: dict, placeholder: str = "PLACEHOLDER" | ||||||
) -> dict: | ||||||
"""Modify the config in place to fill data paths with placeholder strings. | ||||||
|
||||||
Args: | ||||||
|
@@ -75,14 +80,16 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict: | |||||
del input_config["nwp"][nwp_source] | ||||||
else: | ||||||
# Replace the image size | ||||||
nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||||||
nwp_pixel_size = model.nwp_encoders_dict[ | ||||||
nwp_source | ||||||
].image_size_pixels | ||||||
nwp_config["image_size_pixels_height"] = nwp_pixel_size | ||||||
nwp_config["image_size_pixels_width"] = nwp_pixel_size | ||||||
|
||||||
# Replace the interval_end_minutes value | ||||||
nwp_config["interval_end_minutes"] = ( | ||||||
nwp_config["interval_start_minutes"] + | ||||||
(model.nwp_encoders_dict[nwp_source].sequence_length - 1) | ||||||
nwp_config["interval_start_minutes"] | ||||||
+ (model.nwp_encoders_dict[nwp_source].sequence_length - 1) | ||||||
* nwp_config["time_resolution_minutes"] | ||||||
) | ||||||
|
||||||
|
@@ -99,8 +106,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict: | |||||
|
||||||
# Replace the interval_end_minutes value | ||||||
sat_config["interval_end_minutes"] = ( | ||||||
sat_config["interval_start_minutes"] + | ||||||
(model.sat_encoder.sequence_length - 1) | ||||||
sat_config["interval_start_minutes"] | ||||||
+ (model.sat_encoder.sequence_length - 1) | ||||||
* sat_config["time_resolution_minutes"] | ||||||
) | ||||||
|
||||||
|
@@ -158,7 +165,7 @@ def download_from_hf( | |||||
return [f"{save_dir}/{f}" for f in filename] | ||||||
else: | ||||||
return f"{save_dir}/{filename}" | ||||||
|
||||||
except Exception as e: | ||||||
if attempt == max_retries: | ||||||
raise Exception( | ||||||
|
@@ -290,12 +297,14 @@ def save_pretrained( | |||||
# Save the model config and data config | ||||||
if isinstance(model_config, dict): | ||||||
with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile: | ||||||
yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False) | ||||||
yaml.dump( | ||||||
model_config, outfile, sort_keys=False, default_flow_style=False | ||||||
) | ||||||
|
||||||
# Save cleaned version of input data configuration file | ||||||
with open(data_config_path) as cfg: | ||||||
config = yaml.load(cfg, Loader=yaml.FullLoader) | ||||||
|
||||||
config = fill_config_paths_with_placeholder(config) | ||||||
config = minimize_config_for_model(config, self) | ||||||
|
||||||
|
@@ -304,13 +313,17 @@ def save_pretrained( | |||||
|
||||||
# Save the datamodule config | ||||||
if datamodule_config_path is not None: | ||||||
shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||||||
|
||||||
shutil.copyfile( | ||||||
datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME | ||||||
) | ||||||
|
||||||
# Save the full experimental config | ||||||
if experiment_config_path is not None: | ||||||
shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME) | ||||||
|
||||||
card = self.create_hugging_face_model_card(card_template_path, wandb_repo, wandb_ids) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||||||
card = self.create_hugging_face_model_card( | ||||||
card_template_path, wandb_repo, wandb_ids | ||||||
) | ||||||
|
||||||
(save_directory / MODEL_CARD_NAME).write_text(str(card)) | ||||||
|
||||||
|
@@ -370,8 +383,9 @@ def create_hugging_face_model_card( | |||||
|
||||||
# Find package versions for OCF packages | ||||||
packages_to_display = ["pvnet", "ocf-data-sampler"] | ||||||
packages_and_versions = {package: version(package) for package in packages_to_display} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||||||
|
||||||
packages_and_versions = { | ||||||
package: version(package) for package in packages_to_display | ||||||
} | ||||||
|
||||||
package_versions_markdown = "" | ||||||
for package, v in packages_and_versions.items(): | ||||||
|
@@ -392,7 +406,8 @@ def __init__( | |||||
self, | ||||||
history_minutes: int, | ||||||
forecast_minutes: int, | ||||||
output_quantiles: list[float] | None = None, | ||||||
output_quantiles: Optional[list[float]] = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you change the type hint back here and match in the line below?
|
||||||
num_gmm_components: Optional[int] = None, | ||||||
target_key: str = "gsp", | ||||||
interval_minutes: int = 30, | ||||||
): | ||||||
|
@@ -403,6 +418,8 @@ def __init__( | |||||
forecast_minutes (int): Length of the GSP forecast period in minutes | ||||||
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to | ||||||
None the output is a single value. | ||||||
num_gmm_components: Number of Gaussian Mixture Model components to predict parameters for. | ||||||
Mutually exclusive with output_quantiles. If both are None, the output is a single value. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
target_key: The key of the target variable in the batch | ||||||
interval_minutes: The interval in minutes between each timestep in the data | ||||||
""" | ||||||
|
@@ -413,6 +430,7 @@ def __init__( | |||||
self.history_minutes = history_minutes | ||||||
self.forecast_minutes = forecast_minutes | ||||||
self.output_quantiles = output_quantiles | ||||||
self.num_gmm_components = num_gmm_components | ||||||
self.interval_minutes = interval_minutes | ||||||
|
||||||
# Number of timesteps for 30 minutely data | ||||||
|
@@ -422,12 +440,24 @@ def __init__( | |||||
# Store whether the model should use quantile regression or simply predict the mean | ||||||
self.use_quantile_regression = self.output_quantiles is not None | ||||||
|
||||||
self.use_gmm = self.num_gmm_components is not None | ||||||
|
||||||
# Both quantile regression and GMM cannot be used at the same time | ||||||
if self.use_quantile_regression and self.use_gmm: | ||||||
raise ValueError( | ||||||
"Cannot use quantile regression and GMM at the same time. " | ||||||
"Please set either output_quantiles or num_gmm_components to None." | ||||||
) | ||||||
|
||||||
# Store the number of output features that the model should predict for | ||||||
if self.use_quantile_regression: | ||||||
self.num_output_features = self.forecast_len * len(self.output_quantiles) | ||||||
elif self.use_gmm: | ||||||
self.num_output_features = self.forecast_len * self.num_gmm_components * 3 | ||||||
else: | ||||||
self.num_output_features = self.forecast_len | ||||||
|
||||||
|
||||||
def _adapt_batch(self, batch: TensorBatch) -> TensorBatch: | ||||||
"""Slice batches into appropriate shapes for model. | ||||||
|
||||||
|
@@ -491,6 +521,33 @@ def _adapt_batch(self, batch: TensorBatch) -> TensorBatch: | |||||
|
||||||
return new_batch | ||||||
|
||||||
def _parse_gmm_params(
    self, y_gmm: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Reshape the flat network output into (mu, sigma, pi) tensors.

    The flat feature axis is interpreted as
    ``(forecast_len, num_gmm_components, 3)`` in row-major order, where the
    last axis holds, in order: the component mean, the raw (unconstrained)
    scale, and the raw mixture logit.

    Args:
        y_gmm: Network output of shape
            ``(batch, forecast_len * num_gmm_components * 3)``.

    Returns:
        A tuple ``(mus, sigmas, pis)``, each of shape
        ``(batch, forecast_len, num_gmm_components)``:
            mus: Component means, taken from the output unchanged.
            sigmas: Component scales, strictly positive.
            pis: Mixture weights, summing to 1 over the component axis.
    """
    batch_size = y_gmm.shape[0]

    # [batch, forecast_len, num_components, 3]
    params = y_gmm.view(
        batch_size,
        self.forecast_len,
        self.num_gmm_components,
        3,
    )

    mus = params[..., 0]

    # softplus maps the raw scale to a non-negative value; the additive floor
    # keeps sigma bounded away from zero
    sigmas = F.softplus(params[..., 1]) + 1e-3

    # Softmax over the component axis turns the logits into mixture weights
    pis = F.softmax(params[..., 2], dim=-1)

    return mus, sigmas, pis
|
||||||
def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor: | ||||||
""" | ||||||
Convert network prediction into a point prediction. | ||||||
|
@@ -508,4 +565,56 @@ def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor: | |||||
""" | ||||||
# y_quantiles Shape: batch_size, seq_length, num_quantiles | ||||||
idx = self.output_quantiles.index(0.5) | ||||||
return y_quantiles[..., idx] | ||||||
return y_quantiles[..., idx] | ||||||
|
||||||
def _gmm_to_prediction(self, y_gmm: torch.Tensor) -> torch.Tensor:
    """Convert the flat GMM network output into a point forecast.

    The point forecast is the mixture mean, E[Y] = sum_k pi_k * mu_k.

    Note:
        This is the mean of the Gaussian mixture, not the median, whereas
        quantile/MAE training targets the median. If a median forecast is
        wanted, it could be obtained by solving
        F(x) = sum_k pi_k * Phi((x - mu_k) / sigma_k) = 0.5 per horizon step
        (e.g. by bisection), or approximated via sampling.

    Args:
        y_gmm: Flat network output, in the layout accepted by
            ``self._parse_gmm_params``.

    Returns:
        Tensor of shape ``(batch, forecast_len)`` holding the mixture mean.
    """
    # The component scales do not enter the mixture mean
    mus, _, pis = self._parse_gmm_params(y_gmm)

    # Weighted sum over the component axis
    return (pis * mus).sum(dim=-1)
|
||||||
def _sample_from_gmm(
    self,
    mus: torch.Tensor,
    sigmas: torch.Tensor,
    pis: torch.Tensor,
    n_samples: int = 20,
) -> torch.Tensor:
    """Draw samples from a Gaussian mixture for each batch item and timestep.

    Each draw first picks a component according to the mixture weights and
    then samples from that component's normal distribution.

    Args:
        mus: Component means, ``[batch, horizon, components]``.
        sigmas: Component scales, ``[batch, horizon, components]``.
        pis: Mixture weights, ``[batch, horizon, components]``.
        n_samples: Number of samples to draw.

    Returns:
        Tensor of samples with shape ``[n_samples, batch, horizon]``.
    """
    # Component index for every draw: [n_samples, batch, horizon]
    component_idx = torch.distributions.Categorical(probs=pis).sample((n_samples,))

    # Trailing singleton axis so the index can address the component axis
    gather_idx = component_idx.unsqueeze(-1)

    # Broadcast the parameters across the sample axis (no copy) and pick the
    # mean / scale of the selected component for every draw
    tiled_shape = (n_samples, *mus.shape)
    chosen_mus = torch.gather(
        mus.unsqueeze(0).expand(tiled_shape), dim=-1, index=gather_idx
    ).squeeze(-1)
    chosen_sigmas = torch.gather(
        sigmas.unsqueeze(0).expand(tiled_shape), dim=-1, index=gather_idx
    ).squeeze(-1)

    # One normal draw per selected component: [n_samples, batch, horizon]
    return torch.distributions.Normal(chosen_mus, chosen_sigmas).sample()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ def __init__( | |
self, | ||
output_network: AbstractLinearNetwork, | ||
output_quantiles: list[float] | None = None, | ||
num_gmm_components: int | None = None, | ||
nwp_encoders_dict: dict[str, AbstractNWPSatelliteEncoder] | None = None, | ||
sat_encoder: AbstractNWPSatelliteEncoder | None = None, | ||
pv_encoder: AbstractSitesEncoder | None = None, | ||
|
@@ -77,6 +78,9 @@ def __init__( | |
features to produce the forecast. | ||
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to | ||
None the output is a single value. | ||
num_gmm_components: If set to an integer, the model will predict parameters for a | ||
Gaussian mixture model with this many components. Mutually exclusive with | ||
output_quantiles. | ||
nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to | ||
encode the NWP data from 4D into a 1D feature vector from different sources. | ||
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite | ||
|
@@ -118,6 +122,7 @@ def __init__( | |
history_minutes=history_minutes, | ||
forecast_minutes=forecast_minutes, | ||
output_quantiles=output_quantiles, | ||
num_gmm_components=num_gmm_components, | ||
target_key=target_key, | ||
interval_minutes=interval_minutes, | ||
) | ||
|
@@ -171,7 +176,7 @@ def __init__( | |
) | ||
if add_image_embedding_channel: | ||
self.sat_embed = ImageEmbedding( | ||
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels | ||
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels, | ||
) | ||
|
||
# Update num features | ||
|
@@ -230,15 +235,16 @@ def __init__( | |
fusion_input_features += self.pv_encoder.out_features | ||
|
||
if self.use_id_embedding: | ||
self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) | ||
self.embed = nn.Embedding( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||
num_embeddings=num_embeddings, embedding_dim=embedding_dim | ||
) | ||
|
||
# Update num features | ||
fusion_input_features += embedding_dim | ||
|
||
if self.include_sun: | ||
self.sun_fc1 = nn.Linear( | ||
in_features=2 | ||
* (self.forecast_len + self.history_len + 1), | ||
in_features=2 * (self.forecast_len + self.history_len + 1), | ||
out_features=16, | ||
) | ||
|
||
|
@@ -247,8 +253,7 @@ def __init__( | |
|
||
if self.include_time: | ||
self.time_fc1 = nn.Linear( | ||
in_features=4 | ||
* (self.forecast_len + self.history_len + 1), | ||
in_features=4 * (self.forecast_len + self.history_len + 1), | ||
out_features=32, | ||
) | ||
|
||
|
@@ -268,7 +273,6 @@ def __init__( | |
out_features=self.num_output_features, | ||
) | ||
|
||
|
||
def forward(self, x: TensorBatch) -> torch.Tensor: | ||
"""Run model forward""" | ||
|
||
|
@@ -278,7 +282,10 @@ def forward(self, x: TensorBatch) -> torch.Tensor: | |
if self.use_id_embedding: | ||
# eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0] | ||
id = torch.tensor( | ||
[self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||
[ | ||
self.location_id_mapping[i.item()] | ||
for i in x[f"{self._target_key}_id"] | ||
], | ||
device=x[f"{self._target_key}_id"].device, | ||
dtype=torch.int64, | ||
) | ||
|
@@ -288,7 +295,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor: | |
if self.include_sat: | ||
# Shape: batch_size, seq_length, channel, height, width | ||
sat_data = x["satellite_actual"][:, : self.sat_sequence_len] | ||
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unsplit |
||
sat_data = torch.swapaxes( | ||
sat_data, 1, 2 | ||
).float() # switch time and channels | ||
|
||
if self.add_image_embedding_channel: | ||
sat_data = self.sat_embed(sat_data, id) | ||
|
@@ -343,7 +352,7 @@ def forward(self, x: TensorBatch) -> torch.Tensor: | |
sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float() | ||
sun = self.sun_fc1(sun) | ||
modes["sun"] = sun | ||
|
||
if self.include_time: | ||
time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]] | ||
time = torch.cat(time, dim=1).float() | ||
|
@@ -353,7 +362,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor: | |
out = self.output_network(modes) | ||
|
||
if self.use_quantile_regression: | ||
# Shape: batch_size, seq_length * num_quantiles | ||
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles)) | ||
out = out.view(out.size(0), self.forecast_len, len(self.output_quantiles)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Solid refinement! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :) |
||
|
||
# no further reshape needed if gmm is used: BaseModel._parse_gmm_params will view it as | ||
# (batch, forecast_len, num_components, 3) | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please unsplit