diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 36815d4b..869558eb 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -1,4 +1,5 @@ """Base model for all PVNet submodels""" + import copy import logging import os @@ -6,9 +7,11 @@ import time from importlib.metadata import version from pathlib import Path +from typing import Optional import hydra import torch +import torch.nn.functional as F import yaml from huggingface_hub import ModelCard, ModelCardData, snapshot_download from huggingface_hub.hf_api import HfApi @@ -26,7 +29,9 @@ ) -def fill_config_paths_with_placeholder(config: dict, placeholder: str = "PLACEHOLDER") -> dict: +def fill_config_paths_with_placeholder( + config: dict, placeholder: str = "PLACEHOLDER" +) -> dict: """Modify the config in place to fill data paths with placeholder strings. Args: @@ -75,14 +80,16 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict: del input_config["nwp"][nwp_source] else: # Replace the image size - nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels + nwp_pixel_size = model.nwp_encoders_dict[ + nwp_source + ].image_size_pixels nwp_config["image_size_pixels_height"] = nwp_pixel_size nwp_config["image_size_pixels_width"] = nwp_pixel_size # Replace the interval_end_minutes minutes nwp_config["interval_end_minutes"] = ( - nwp_config["interval_start_minutes"] + - (model.nwp_encoders_dict[nwp_source].sequence_length - 1) + nwp_config["interval_start_minutes"] + + (model.nwp_encoders_dict[nwp_source].sequence_length - 1) * nwp_config["time_resolution_minutes"] ) @@ -99,8 +106,8 @@ def minimize_config_for_model(config: dict, model: "BaseModel") -> dict: # Replace the interval_end_minutes minutes sat_config["interval_end_minutes"] = ( - sat_config["interval_start_minutes"] + - (model.sat_encoder.sequence_length - 1) + sat_config["interval_start_minutes"] + + (model.sat_encoder.sequence_length - 1) * sat_config["time_resolution_minutes"] ) @@ -158,7 +165,7 @@ def download_from_hf( return [f"{save_dir}/{f}" for f in filename] else: return f"{save_dir}/{filename}" - + except Exception as e: if attempt == max_retries: raise Exception( @@ -290,12 +297,14 @@ def save_pretrained( # Save the model config and data config if isinstance(model_config, dict): with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile: - yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False) + yaml.dump( + model_config, outfile, sort_keys=False, default_flow_style=False + ) # Save cleaned version of input data configuration file with open(data_config_path) as cfg: config = yaml.load(cfg, Loader=yaml.FullLoader) - + config = fill_config_paths_with_placeholder(config) config = minimize_config_for_model(config, self) @@ -304,13 +313,17 @@ def save_pretrained( # Save the datamodule config if datamodule_config_path is not None: - shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME) - + shutil.copyfile( + datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME + ) + # Save the full experimental config if experiment_config_path is not None: shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME) - card = self.create_hugging_face_model_card(card_template_path, wandb_repo, wandb_ids) + card = self.create_hugging_face_model_card( + card_template_path, wandb_repo, wandb_ids + ) (save_directory / MODEL_CARD_NAME).write_text(str(card)) @@ -370,8 +383,9 @@ def create_hugging_face_model_card( # Find package versions for OCF 
packages packages_to_display = ["pvnet", "ocf-data-sampler"] - packages_and_versions = {package: version(package) for package in packages_to_display} - + packages_and_versions = { + package: version(package) for package in packages_to_display + } package_versions_markdown = "" for package, v in packages_and_versions.items(): @@ -392,7 +406,8 @@ def __init__( self, history_minutes: int, forecast_minutes: int, - output_quantiles: list[float] | None = None, + output_quantiles: Optional[list[float]] = None, + num_gmm_components: Optional[int] = None, target_key: str = "gsp", interval_minutes: int = 30, ): @@ -403,6 +418,8 @@ def __init__( forecast_minutes (int): Length of the GSP forecast period in minutes output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to None the output is a single value. + num_gmm_components: Number of Gaussian Mixture Model components to use for the model. + Mutually exclusive with output_quantiles. If both are None, the output is a single value. target_key: The key of the target variable in the batch interval_minutes: The interval in minutes between each timestep in the data """ @@ -413,6 +430,7 @@ def __init__( self.history_minutes = history_minutes self.forecast_minutes = forecast_minutes self.output_quantiles = output_quantiles + self.num_gmm_components = num_gmm_components self.interval_minutes = interval_minutes # Number of timestemps for 30 minutely data @@ -422,12 +440,24 @@ def __init__( # Store whether the model should use quantile regression or simply predict the mean self.use_quantile_regression = self.output_quantiles is not None + self.use_gmm = self.num_gmm_components is not None + + # Both quantile regression and GMM cannot be used at the same time + if self.use_quantile_regression and self.use_gmm: + raise ValueError( + "Cannot use quantile regression and GMM at the same time. " + "Please set either output_quantiles or num_gmm_components to None." + ) + # Store the number of ouput features that the model should predict for if self.use_quantile_regression: self.num_output_features = self.forecast_len * len(self.output_quantiles) + elif self.use_gmm: + self.num_output_features = self.forecast_len * self.num_gmm_components * 3 else: self.num_output_features = self.forecast_len + def _adapt_batch(self, batch: TensorBatch) -> TensorBatch: """Slice batches into appropriate shapes for model. @@ -491,6 +521,33 @@ def _adapt_batch(self, batch: TensorBatch) -> TensorBatch: return new_batch + def _parse_gmm_params(self, y_gmm): + """ + Reshape flat output into (μ, σ, π) tensors. + + y_gmm: (batch, forecast_len * num_components * 3) + + Returns: + mus: (batch, forecast_len, num_components) + sigmas: (batch, forecast_len, num_components) + pis: (batch, forecast_len, num_components) + """ + bsz = y_gmm.shape[0] + # reshape to [batch, forecast_len, num_components, 3] + params = y_gmm.view( + bsz, + self.forecast_len, + self.num_gmm_components, + 3, + ) + mus = params[..., 0] + # enforce positivity & stability + sigmas = F.softplus(params[..., 1]) + 1e-3 + # softmax over components to get mixture weights + logits = params[..., 2] + pis = F.softmax(logits, dim=-1) + return mus, sigmas, pis + def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor: """ Convert network prediction into a point prediction.
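Reviewer note (not part of the diff): a minimal standalone sketch of the flat head layout that `_parse_gmm_params` assumes; `B`/`H`/`K` and all variable names below are illustrative only.

import torch
import torch.nn.functional as F

B, H, K = 2, 4, 3                           # batch, forecast_len, num_gmm_components
y_gmm = torch.randn(B, H * K * 3)           # flat head output: num_output_features = H * K * 3

params = y_gmm.view(B, H, K, 3)             # last dim holds [mu, raw sigma, logit]
mus = params[..., 0]
sigmas = F.softplus(params[..., 1]) + 1e-3  # strictly positive scales
pis = F.softmax(params[..., 2], dim=-1)     # mixture weights sum to 1 over K

assert torch.allclose(pis.sum(dim=-1), torch.ones(B, H))
point = (pis * mus).sum(dim=-1)             # mixture mean E[Y], shape [B, H]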
@@ -508,4 +565,56 @@ def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor: """ # y_quantiles Shape: batch_size, seq_length, num_quantiles idx = self.output_quantiles.index(0.5) - return y_quantiles[..., idx] \ No newline at end of file + return y_quantiles[..., idx] + + def _gmm_to_prediction(self, y_gmm): + """ + Return the **mixture mean** E[Y] = Σ_k π_k μ_k as a point forecast. + + Note: + • This is the mean of the Gaussian mixture, not the median. + Quantile/MAE training targets the median. + • Possible extension: if a median point forecast is desired, it can be obtained by solving + F(x) = Σ_k π_k Φ((x - μ_k)/σ_k) = 0.5 per horizon step (e.g., bisection), + or approximated via sampling. + """ + mus, sigmas, pis = self._parse_gmm_params(y_gmm) + # expectation over components + y_pred = (pis * mus).sum(dim=-1) # Here we sum over the components + # y_pred shape: (batch, forecast_len) + return y_pred + + def _sample_from_gmm(self, mus, sigmas, pis, n_samples=20): + """ + Sample from Gaussian Mixture Model for each timestep and batch. + + Args: + mus: [batch, horizon, components] + sigmas: [batch, horizon, components] + pis: [batch, horizon, components] + n_samples: Number of samples to draw + + Returns: + samples: [n_samples, batch, horizon] + """ + batch, horizon, num_components = mus.shape + + # Sample component indices according to mixture weights (pis) + categorical = torch.distributions.Categorical(pis) + component_indices = categorical.sample( + (n_samples,) + ) # [n_samples, batch, horizon] + + # Gather the corresponding μ and σ for each sampled index + mus_samples = mus.unsqueeze(0).expand(n_samples, -1, -1, -1) + sigmas_samples = sigmas.unsqueeze(0).expand(n_samples, -1, -1, -1) + + idx = component_indices.unsqueeze(-1) + gathered_mus = torch.gather(mus_samples, dim=3, index=idx).squeeze(-1) + gathered_sigmas = torch.gather(sigmas_samples, dim=3, index=idx).squeeze(-1) + + # Sample from the normal distribution + normal = torch.distributions.Normal(gathered_mus, gathered_sigmas) + samples = normal.sample() # [n_samples, batch, horizon] + + return samples diff --git a/pvnet/models/late_fusion/late_fusion.py b/pvnet/models/late_fusion/late_fusion.py index 3d835d64..49fdd003 100644 --- a/pvnet/models/late_fusion/late_fusion.py +++ b/pvnet/models/late_fusion/late_fusion.py @@ -39,6 +39,7 @@ def __init__( self, output_network: AbstractLinearNetwork, output_quantiles: list[float] | None = None, + num_gmm_components: int | None = None, nwp_encoders_dict: dict[str, AbstractNWPSatelliteEncoder] | None = None, sat_encoder: AbstractNWPSatelliteEncoder | None = None, pv_encoder: AbstractSitesEncoder | None = None, @@ -77,6 +78,9 @@ def __init__( features to produce the forecast. output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to None the output is a single value. + num_gmm_components: If set to an integer, the model will predict parameters for a + Gaussian mixture model with this many components. Mutually exclusive with + output_quantiles. nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to encode the NWP data from 4D into a 1D feature vector from different sources.
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite @@ -118,6 +122,7 @@ def __init__( history_minutes=history_minutes, forecast_minutes=forecast_minutes, output_quantiles=output_quantiles, + num_gmm_components=num_gmm_components, target_key=target_key, interval_minutes=interval_minutes, ) @@ -171,7 +176,7 @@ def __init__( ) if add_image_embedding_channel: self.sat_embed = ImageEmbedding( - num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels + num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels, ) # Update num features @@ -230,15 +235,16 @@ def __init__( fusion_input_features += self.pv_encoder.out_features if self.use_id_embedding: - self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.embed = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) # Update num features fusion_input_features += embedding_dim if self.include_sun: self.sun_fc1 = nn.Linear( - in_features=2 - * (self.forecast_len + self.history_len + 1), + in_features=2 * (self.forecast_len + self.history_len + 1), out_features=16, ) @@ -247,8 +253,7 @@ def __init__( if self.include_time: self.time_fc1 = nn.Linear( - in_features=4 - * (self.forecast_len + self.history_len + 1), + in_features=4 * (self.forecast_len + self.history_len + 1), out_features=32, ) @@ -268,7 +273,6 @@ def __init__( out_features=self.num_output_features, ) - def forward(self, x: TensorBatch) -> torch.Tensor: """Run model forward""" @@ -278,7 +282,10 @@ def forward(self, x: TensorBatch) -> torch.Tensor: if self.use_id_embedding: # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0] id = torch.tensor( - [self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]], + [ + self.location_id_mapping[i.item()] + for i in x[f"{self._target_key}_id"] + ], device=x[f"{self._target_key}_id"].device, dtype=torch.int64, ) @@ -288,7 +295,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor: if self.include_sat: # Shape: batch_size, seq_length, channel, height, width sat_data = x["satellite_actual"][:, : self.sat_sequence_len] - sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels + sat_data = torch.swapaxes( + sat_data, 1, 2 + ).float() # switch time and channels if self.add_image_embedding_channel: sat_data = self.sat_embed(sat_data, id) @@ -343,7 +352,7 @@ def forward(self, x: TensorBatch) -> torch.Tensor: sun = torch.cat((x["solar_azimuth"], x["solar_elevation"]), dim=1).float() sun = self.sun_fc1(sun) modes["sun"] = sun - + if self.include_time: time = [x[k] for k in ["date_sin", "date_cos", "time_sin", "time_cos"]] time = torch.cat(time, dim=1).float() @@ -353,7 +362,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor: out = self.output_network(modes) if self.use_quantile_regression: - # Shape: batch_size, seq_length * num_quantiles - out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles)) + out = out.view(out.size(0), self.forecast_len, len(self.output_quantiles)) + + # no further reshape needed if gmm is used: BaseModel._parse_gmm_params will view it as + # (batch, forecast_len, num_components, 3) return out diff --git a/pvnet/training/lightning_module.py b/pvnet/training/lightning_module.py index 80b9efbe..34eac545 100644 --- a/pvnet/training/lightning_module.py +++ b/pvnet/training/lightning_module.py @@ -10,6 +10,8 @@ import xarray as xr from ocf_data_sampler.numpy_sample.common_types import TensorBatch from 
ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device +from torch.distributions import Normal +from torchmetrics.functional.regression import continuous_ranked_probability_score as crps_fn from pvnet.data.base_datamodule import collate_fn from pvnet.models.base_model import BaseModel @@ -46,15 +48,17 @@ def __init__( self.save_all_validation_results = save_all_validation_results def transfer_batch_to_device( - self, - batch: TensorBatch, - device: torch.device, + self, + batch: TensorBatch, + device: torch.device, dataloader_idx: int, ) -> dict: """Method to move custom batches to a given device""" return copy_batch_to_device(batch, device) - def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _calculate_quantile_loss( + self, y_quantiles: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: """Calculate quantile loss. Note: @@ -76,7 +80,29 @@ def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) - losses = 2 * torch.cat(losses, dim=2) return losses.mean() - + + def _calculate_nll(self, y_gmm, y_true): + """ + Negative log-likelihood of y_true under the predicted GMM. + + Args: + y_gmm: (batch, forecast_len * num_components * 3) + y_true: (batch, forecast_len) + """ + mus, sigmas, pis = self.model._parse_gmm_params(y_gmm) + # expand y_true to [batch, forecast_len, num_components] + y_exp = y_true.unsqueeze(-1).expand_as(mus) + # compute component log-probs + comp = Normal(mus, sigmas) + log_p = comp.log_prob(y_exp) # [batch, forecast_len, num_components] + # weight them + weighted = log_p + torch.log(pis + 1e-12) + # log-sum-exp over components + log_probs = torch.logsumexp(weighted, dim=-1) # [batch, forecast_len] + # negative log-likelihood + nll = -log_probs.mean() # mean over batch & horizon + return nll + def configure_optimizers(self): """Configure the optimizers using learning rate found with LR finder if used""" if self.lr is not None: @@ -85,7 +111,7 @@ def configure_optimizers(self): return self._optimizer(self.model) def _calculate_common_losses( - self, + self, y: torch.Tensor, y_hat: torch.Tensor, ) -> dict[str, torch.Tensor]: @@ -95,12 +121,14 @@ def _calculate_common_losses( if self.model.use_quantile_regression: losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y) - y_hat = self.model._quantiles_to_prediction(y_hat) + elif self.model.use_gmm: + losses["nll"] = self._calculate_nll(y_hat, y) + y_hat = self.model._gmm_to_prediction(y_hat) - losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)}) + losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)}) return losses - + def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor: """Run training step""" y_hat = self.model(batch) @@ -117,13 +145,15 @@ def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor: if self.model.use_quantile_regression: opt_target = losses["quantile_loss/train"] + elif self.model.use_gmm: + opt_target = losses["nll/train"] else: opt_target = losses["MAE/train"] return opt_target - + def _calculate_val_losses( - self, - y: torch.Tensor, + self, + y: torch.Tensor, y_hat: torch.Tensor, ) -> dict[str, torch.Tensor]: """Calculate additional losses only run in validation""" @@ -139,50 +169,117 @@ def _calculate_val_losses( mask = y >= 0.01 losses[metric_name.format(quantile)] = below_quant[mask].float().mean() + b, h, q = y_hat.shape + # crps_fn expects preds with last dim = ensemble members + losses["CRPS"] = 
crps_fn(preds=y_hat.reshape(b * h, q), target=y.reshape(-1)) + + if self.model.use_gmm: + # Convert GMM into samples or quantiles + mus, sigmas, pis = self.model._parse_gmm_params(y_hat) # shape: [B, H, C] + + # Sample from GMM to get an ensemble of predictions + num_samples = 20 + samples = self.model._sample_from_gmm( + mus, sigmas, pis, n_samples=num_samples + ) + # samples: [num_samples, batch, forecast_len] + + # reshape for TorchMetrics: [batch * forecast_len, ensemble_members] + ensemble = samples.permute(1, 2, 0).reshape(-1, num_samples) # [B*H, N] + targets = y.reshape(-1) # [B*H] + + losses["CRPS"] = crps_fn(preds=ensemble, target=targets) + + # Calculate the GMM loss + losses["nll/val"] = self._calculate_nll(y_hat, y) + # Collapse to mixture mean for further metrics + y_hat = self.model._gmm_to_prediction(y_hat) + return losses def _calculate_step_metrics( - self, - y: torch.Tensor, - y_hat: torch.Tensor, + self, + y: torch.Tensor, + y_hat: torch.Tensor, ) -> tuple[np.array, np.array]: """Calculate the MAE and MSE at each forecast step""" mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy() mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy() - + return mae_each_step, mse_each_step - + def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None: """Internally store the validation predictions""" - - taregt_key = self.model._target_key - y = batch[taregt_key][:, -self.model.forecast_len :].cpu().numpy() - y_hat = y_hat.cpu().numpy() - ids = batch[f"{taregt_key}_id"].cpu().numpy() + target_key = self.model._target_key + + y = batch[target_key][:, -self.model.forecast_len :].cpu() + ids = batch[f"{target_key}_id"].cpu().numpy() init_times_utc = pd.to_datetime( - batch[f"{taregt_key}_time_utc"][:, self.model.history_len+1] - .cpu().numpy().astype("datetime64[ns]") + batch[f"{target_key}_time_utc"][:, self.model.history_len + 1] + .cpu() + .numpy() + .astype("datetime64[ns]") ) - if self.model.use_quantile_regression: + data_vars = { + "y": (["sample_num", "forecast_step"], y.numpy()), + } + coords = { + "ids": ("sample_num", ids), + "init_times_utc": ("sample_num", init_times_utc), + } + + if self.model.use_gmm: + # Parse GMM parameters from the raw output + mus, sigmas, pis = self.model._parse_gmm_params(y_hat) + + # Move tensors to CPU and convert to numpy for storage + mus = mus.cpu().numpy() + sigmas = sigmas.cpu().numpy() + pis = pis.cpu().numpy() + + # Store parameters for each component + for i in range(self.model.num_gmm_components): + data_vars[f"gmm_mean_{i}"] = ( + ["sample_num", "forecast_step"], + mus[:, :, i], + ) + data_vars[f"gmm_std_{i}"] = ( + ["sample_num", "forecast_step"], + sigmas[:, :, i], + ) + data_vars[f"gmm_weight_{i}"] = ( + ["sample_num", "forecast_step"], + pis[:, :, i], + ) + + # Also store the point prediction (mixture mean) + y_pred = (pis * mus).sum(axis=-1) + data_vars["y_pred"] = (["sample_num", "forecast_step"], y_pred) + + elif self.model.use_quantile_regression: + y_hat = y_hat.cpu().numpy() p_levels = self.model.output_quantiles + data_vars["y_hat"] = (["sample_num", "forecast_step", "p_level"], y_hat) + coords["p_level"] = p_levels + else: + # Handle the simple point prediction case + y_hat = y_hat.cpu().numpy() p_levels = [0.5] - y_hat = y_hat[..., None] + data_vars["y_hat"] = ( + ["sample_num", "forecast_step", "p_level"], + y_hat[..., None], + ) + coords["p_level"] = p_levels ds_preds_batch = xr.Dataset( - data_vars=dict( - y_hat=(["sample_num", "forecast_step", "p_level"], y_hat), - 
y=(["sample_num", "forecast_step"], y), - ), - coords=dict( - ids=("sample_num", ids), - init_times_utc=("sample_num", init_times_utc), - p_level=p_levels, - ), + data_vars=data_vars, + coords=coords, ) + self.all_val_results.append(ds_preds_batch) def on_validation_epoch_start(self): @@ -190,9 +287,9 @@ def on_validation_epoch_start(self): # Set up stores which we will fill during validation self.all_val_results: list[xr.Dataset] = [] self._val_horizon_maes: list[np.array] = [] - if self.current_epoch==0: + if self.current_epoch == 0: self._val_persistence_horizon_maes: list[np.array] = [] - + # Plot some sample forecasts val_dataset = self.trainer.val_dataloaders.dataset @@ -201,23 +298,23 @@ def on_validation_epoch_start(self): for plot_num in range(num_figures): idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure - idxs = idxs[idxs None: # Calculate the horizon MAE/MSE metrics if self.model.use_quantile_regression: y_hat_mid = self.model._quantiles_to_prediction(y_hat) + elif self.model.use_gmm: + y_hat_mid = self.model._gmm_to_prediction(y_hat) else: y_hat_mid = y_hat @@ -261,21 +360,24 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None: # Calculate the persistance losses - we only need to do this once per training run # not every epoch - if self.current_epoch==0: + if self.current_epoch == 0: y_persist = ( - batch[self.model._target_key][:, -(self.model.forecast_len+1)] - .unsqueeze(1).expand(-1, self.model.forecast_len) + batch[self.model._target_key][:, -(self.model.forecast_len + 1)] + .unsqueeze(1) + .expand(-1, self.model.forecast_len) + ) + mae_step_persist, mse_step_persist = self._calculate_step_metrics( + y, y_persist ) - mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist) self._val_persistence_horizon_maes.append(mae_step_persist) losses.update( { - "MAE/val_persistence": mae_step_persist.mean(), - "MSE/val_persistence": mse_step_persist.mean() + "MAE/val_persistence": mae_step_persist.mean(), + "MSE/val_persistence": mse_step_persist.mean(), } ) - # Log the metrics + # Log the metrics self.log_dict(losses, on_step=False, on_epoch=True) def on_validation_epoch_end(self) -> None: @@ -288,13 +390,23 @@ def on_validation_epoch_end(self) -> None: self._val_horizon_maes = [] # We only run this on the first epoch - if self.current_epoch==0: - val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0) + if self.current_epoch == 0: + val_persistence_horizon_maes = np.mean( + self._val_persistence_horizon_maes, axis=0 + ) self._val_persistence_horizon_maes = [] if isinstance(self.logger, pl.loggers.WandbLogger): - # Calculate and log extreme error metrics - val_error = ds_val_results["y"] - ds_val_results["y_hat"].sel(p_level=0.5) + # Determine the point prediction based on the model type + if self.model.use_gmm: + # For GMM, the point prediction 'y_pred' (mixture mean) is already calculated + point_prediction = ds_val_results["y_pred"] + else: + # For Quantiles or simple forecasts, use the median (p_level=0.5) + point_prediction = ds_val_results["y_hat"].sel(p_level=0.5) + + # Calculate the error based on the correct point prediction + val_error = ds_val_results["y"] - point_prediction # Factor out this part of the string for brevity below s = "error_extremes/{}_percentile_median_forecast_error" @@ -320,28 +432,32 @@ def on_validation_epoch_end(self) -> None: wandb_log_dir = self.logger.experiment.dir filepath = f"{wandb_log_dir}/validation_results.netcdf" ds_val_results.to_netcdf(filepath) - - # 
Uplodad to wandb - self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now") - + + # Upload to wandb + self.logger.experiment.save( + filepath, base_path=wandb_log_dir, policy="now" + ) + # Create the horizon accuracy curve horizon_mae_plot = wandb_line_plot( - x=np.arange(self.model.forecast_len), + x=np.arange(self.model.forecast_len), y=val_horizon_maes, xlabel="Horizon step", ylabel="MAE", title="Val horizon loss curve", ) - + wandb.log({"val_horizon_mae_plot": horizon_mae_plot}) # Create persistence horizon accuracy curve but only on first epoch - if self.current_epoch==0: + if self.current_epoch == 0: persist_horizon_mae_plot = wandb_line_plot( - x=np.arange(self.model.forecast_len), + x=np.arange(self.model.forecast_len), y=val_persistence_horizon_maes, xlabel="Horizon step", ylabel="MAE", title="Val persistence horizon loss curve", ) - wandb.log({"persistence_val_horizon_mae_plot": persist_horizon_mae_plot}) + wandb.log( + {"persistence_val_horizon_mae_plot": persist_horizon_mae_plot} + ) diff --git a/pvnet/training/plots.py b/pvnet/training/plots.py index 6bb82b79..9629e512 100644 --- a/pvnet/training/plots.py +++ b/pvnet/training/plots.py @@ -1,20 +1,20 @@ """Plots logged during training""" + from collections.abc import Sequence import matplotlib.pyplot as plt import pandas as pd -import pylab import torch import wandb from ocf_data_sampler.numpy_sample.common_types import TensorBatch def wandb_line_plot( - x: Sequence[float], - y: Sequence[float], - xlabel: str, - ylabel: str, - title: str | None = None + x: Sequence[float], + y: Sequence[float], + xlabel: str, + ylabel: str, + title: str | None = None, ) -> wandb.plot.CustomChart: """Make a wandb line plot""" data = [[xi, yi] for (xi, yi) in zip(x, y)] @@ -25,62 +25,123 @@ def plot_sample_forecasts( batch: TensorBatch, y_hat: torch.Tensor, - quantiles: list[float] | None, + model, key_to_plot: str, ) -> plt.Figure: """Plot a batch of data and the forecast from that batch""" - y = batch[key_to_plot].cpu().numpy() - y_hat = y_hat.cpu().numpy() + y = batch[key_to_plot].cpu() + y_hat = y_hat.cpu() ids = batch[f"{key_to_plot}_id"].cpu().numpy().squeeze() times_utc = pd.to_datetime( - batch[f"{key_to_plot}_time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]") + batch[f"{key_to_plot}_time_utc"] + .cpu() + .numpy() + .squeeze() + .astype("datetime64[ns]") ) batch_size = y.shape[0] fig, axes = plt.subplots(4, 4, figsize=(16, 16)) for i, ax in enumerate(axes.ravel()[:batch_size]): + # Get the forecast-only part of the ground truth and time + y_true_forecast = y[i, -model.forecast_len :] + times_forecast = times_utc[i, -model.forecast_len :] + + # Plot ground truth + ax.plot( + times_forecast, + y_true_forecast, + marker=".", + color="k", + label=r"True Value ($y$)", + ) + + if model.use_gmm: + mus, sigmas, pis = model._parse_gmm_params(y_hat[i : i + 1]) + mus, sigmas, pis = mus.squeeze(0), sigmas.squeeze(0), pis.squeeze(0) + + mixture_mean = torch.sum(pis * mus, dim=-1) + mixture_variance = torch.sum( + pis * (mus.pow(2) + sigmas.pow(2)), dim=-1 + ) - mixture_mean.pow(2) + mixture_std = torch.sqrt(mixture_variance.clamp(min=1e-6)) - ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$") + ax.plot( + times_forecast, + mixture_mean.numpy(), + marker=".", + color="red", + label=r"Mixture Mean", + ) + + # NOTE: The 90% band below is an approximation. + # We moment-match the GMM to a single Normal (mean/variance) and use ±1.645 * std.
+ # This will not capture asymmetry or multi-modality of the true GMM... + lower_bound = mixture_mean - 1.645 * mixture_std + upper_bound = mixture_mean + 1.645 * mixture_std + ax.fill_between( + times_forecast, + lower_bound.numpy(), + upper_bound.numpy(), + color="red", + alpha=0.2, + label="90% Confidence", + ) + + elif model.use_quantile_regression: + y_hat_i = y_hat[i] + quantiles = model.output_quantiles + median_idx = quantiles.index(0.5) - if quantiles is None: ax.plot( - times_utc[i][-len(y_hat[i]) :], - y_hat[i], - marker=".", - color="r", - label=r"$\hat{y}$", + times_forecast, + y_hat_i[:, median_idx], + marker=".", + color="blue", + label=r"Median", ) - else: - cm = pylab.get_cmap("twilight") - for nq, q in enumerate(quantiles): - ax.plot( - times_utc[i][-len(y_hat[i]) :], - y_hat[i, :, nq], - color=cm(q), - label=r"$\hat{y}$" + f"({q})", - alpha=0.7, + + num_quantiles = len(quantiles) + for j in range(num_quantiles // 2): + l_q, u_q = quantiles[j], quantiles[num_quantiles - 1 - j] + ax.fill_between( + times_forecast, + y_hat_i[:, j], + y_hat_i[:, num_quantiles - 1 - j], + alpha=0.2, + color="blue", + label=f"{l_q*100:.0f}-{u_q*100:.0f}%", ) + else: + ax.plot( + times_forecast, + y_hat[i], + marker=".", + color="green", + label=r"Point Forecast", + ) ax.set_title(f"ID: {ids[i]} | {times_utc[i][0].date()}", fontsize="small") + xticks = [t for t in pd.to_datetime(times_forecast) if t.minute == 0][::2] + ax.set_xticks( + ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90 + ) + ax.grid(True, which="both", linestyle="--", linewidth=0.5) - xticks = [t for t in times_utc[i] if t.minute == 0][::2] - ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90) - ax.grid() - - axes[0, 0].legend(loc="best") + handles, labels = axes[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.995), ncol=4) - if batch_size<16: + if batch_size < 16: for ax in axes.ravel()[batch_size:]: ax.axis("off") - + for ax in axes[-1, :]: ax.set_xlabel("Time (hour of day)") - title = f"Normed {key_to_plot.upper()} output" - - plt.suptitle(title) - plt.tight_layout() + title = f"Normalized {key_to_plot.upper()} Power" + fig.suptitle(title, fontsize=16) + plt.tight_layout(rect=[0, 0.03, 1, 0.95]) return fig diff --git a/tests/conftest.py b/tests/conftest.py index a912da4b..07ac067a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,10 +14,9 @@ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration from pvnet.data.base_datamodule import collate_fn -from pvnet.data import UKRegionalStreamedDataModule, SiteStreamedDataModule +from pvnet.data import UKRegionalStreamedDataModule, SiteStreamedDataModule from pvnet.models import LateFusionModel - - +from pvnet.models.base_model import BaseModel _top_test_directory = os.path.dirname(os.path.realpath(__file__)) @@ -53,8 +52,17 @@ def session_tmp_path(tmp_path_factory): @pytest.fixture(scope="session") def sat_zarr_path(session_tmp_path) -> str: variables = [ - "IR_016", "IR_039", "IR_087", "IR_097", "IR_108", "IR_120", - "IR_134", "VIS006", "VIS008", "WV_062", "WV_073", + "IR_016", + "IR_039", + "IR_087", + "IR_097", + "IR_108", + "IR_120", + "IR_134", + "VIS006", + "VIS008", + "WV_062", + "WV_073", ] times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min") y = np.linspace(start=4191563, stop=5304712, num=100) @@ -90,7 +98,7 @@ def ukv_zarr_path(session_tmp_path) -> str: steps = pd.timedelta_range("0h", "24h", freq="1h") x 
= np.linspace(-239_000, 857_000, 200) y = np.linspace(-183_000, 1425_000, 200) - + coords = ( ("init_time", init_times), ("variable", variables), @@ -148,18 +156,22 @@ def gsp_zarr_path(session_tmp_path) -> str: times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min") gsp_ids = np.arange(0, 318) capacity = np.ones((len(times), len(gsp_ids))) - generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(np.float32) + generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype( + np.float32 + ) coords = ( ("datetime_gmt", times), ("gsp_id", gsp_ids), ) - ds_uk_gsp = xr.Dataset({ - "capacity_mwp": xr.DataArray(capacity, coords=coords), - "installedcapacity_mwp": xr.DataArray(capacity, coords=coords), - "generation_mw": xr.DataArray(generation, coords=coords), - }) + ds_uk_gsp = xr.Dataset( + { + "capacity_mwp": xr.DataArray(capacity, coords=coords), + "installedcapacity_mwp": xr.DataArray(capacity, coords=coords), + "generation_mw": xr.DataArray(generation, coords=coords), + } + ) zarr_path = session_tmp_path / "uk_gsp.zarr" ds_uk_gsp.to_zarr(zarr_path) @@ -179,12 +191,12 @@ def site_data_paths(session_tmp_path) -> tuple[str, str]: coords = (("time_utc", times), ("site_id", site_ids)) generation_data = np.random.uniform( - low=0, - high=200, - size=tuple(len(coord_values) for _, coord_values in coords) + low=0, high=200, size=tuple(len(coord_values) for _, coord_values in coords) ).astype(np.float32) - ds_gen = xr.DataArray(generation_data, coords=coords).to_dataset(name="generation_kw") + ds_gen = xr.DataArray(generation_data, coords=coords).to_dataset( + name="generation_kw" + ) df_meta = pd.DataFrame( { @@ -205,15 +217,13 @@ def site_data_paths(session_tmp_path) -> tuple[str, str]: @pytest.fixture(scope="session") def uk_data_config_path( - session_tmp_path, - sat_zarr_path, - ukv_zarr_path, - ecmwf_zarr_path, - gsp_zarr_path -) -> str: - + session_tmp_path, sat_zarr_path, ukv_zarr_path, ecmwf_zarr_path, gsp_zarr_path +) -> str: + # Populate the config with the generated zarr paths - config = load_yaml_configuration(f"{_top_test_directory}/test_data/uk_data_config.yaml") + config = load_yaml_configuration( + f"{_top_test_directory}/test_data/uk_data_config.yaml" + ) config.input_data.nwp["ukv"].zarr_path = str(ukv_zarr_path) config.input_data.nwp["ecmwf"].zarr_path = str(ecmwf_zarr_path) config.input_data.satellite.zarr_path = str(sat_zarr_path) @@ -226,15 +236,17 @@ def uk_data_config_path( @pytest.fixture(scope="session") def site_data_config_path( - session_tmp_path, - sat_zarr_path, - ukv_zarr_path, - ecmwf_zarr_path, + session_tmp_path, + sat_zarr_path, + ukv_zarr_path, + ecmwf_zarr_path, site_data_paths, -) -> str: - +) -> str: + # Populate the config with the generated zarr paths - config = load_yaml_configuration(f"{_top_test_directory}/test_data/site_data_config.yaml") + config = load_yaml_configuration( + f"{_top_test_directory}/test_data/site_data_config.yaml" + ) config.input_data.nwp["ukv"].zarr_path = str(ukv_zarr_path) config.input_data.nwp["ecmwf"].zarr_path = str(ecmwf_zarr_path) config.input_data.satellite.zarr_path = str(sat_zarr_path) @@ -309,7 +321,7 @@ def site_encoder_model_kwargs() -> dict: sequence_length=60 // 15 + 1, num_sites=1, out_features=128, - target_key_to_use="site" + target_key_to_use="site", ) @@ -345,7 +357,6 @@ def raw_late_fusion_model_kwargs(model_minutes_kwargs) -> dict: image_size_pixels=12, ), }, - add_image_embedding_channel=True, output_network=dict( 
_target_="pvnet.models.late_fusion.linear_networks.networks.ResFCNet", @@ -355,7 +366,7 @@ def raw_late_fusion_model_kwargs(model_minutes_kwargs) -> dict: res_block_layers=2, dropout_frac=0.0, ), - location_id_mapping={i:i for i in range(1, 318)}, + location_id_mapping={i: i for i in range(1, 318)}, embedding_dim=16, include_sun=True, include_gsp_yield_history=True, @@ -400,22 +411,72 @@ def raw_late_fusion_model_kwargs_site_history(model_minutes_kwargs) -> dict: include_time=True, include_gsp_yield_history=False, include_site_yield_history=True, - forecast_minutes=480, + forecast_minutes=480, history_minutes=60, interval_minutes=15, ) @pytest.fixture() -def late_fusion_model_kwargs_site_history(raw_late_fusion_model_kwargs_site_history) -> dict: +def late_fusion_model_kwargs_site_history( + raw_late_fusion_model_kwargs_site_history, +) -> dict: return hydra.utils.instantiate(raw_late_fusion_model_kwargs_site_history) @pytest.fixture() -def late_fusion_model_site_history(late_fusion_model_kwargs_site_history) -> LateFusionModel: +def late_fusion_model_site_history( + late_fusion_model_kwargs_site_history, +) -> LateFusionModel: return LateFusionModel(**late_fusion_model_kwargs_site_history) @pytest.fixture() def late_fusion_quantile_model(late_fusion_model_kwargs) -> LateFusionModel: return LateFusionModel(output_quantiles=[0.1, 0.5, 0.9], **late_fusion_model_kwargs) + + +@pytest.fixture() +def gmm_model_factory(): + """ + Factory for a minimal GMM-capable model that uses BaseModel's GMM helpers + without pulling in encoders or datamodule complexity. + Usage in tests: model = gmm_model_factory(forecast_len=H, num_components=K) + """ + + class _MinimalGMMModel(BaseModel): + def __init__(self, forecast_len=3, num_components=2): + super().__init__( + history_minutes=0, + forecast_minutes=forecast_len * 30, # interval_minutes=30 + output_quantiles=None, + num_gmm_components=num_components, + interval_minutes=30, + ) + self.include_sat = False + self.include_nwp = False + self.include_sun = False + + def factory(forecast_len=3, num_components=2): + return _MinimalGMMModel( + forecast_len=forecast_len, num_components=num_components + ) + + return factory + + +@pytest.fixture() +def build_y_gmm_from_params(): + """ + Returns a function that stacks (mus, sigma_raws, logits) into the flat y_gmm + shape expected by BaseModel._parse_gmm_params: [B, H*K*3]. 
+ """ + + def _build( + mus: torch.Tensor, sigma_raws: torch.Tensor, logits: torch.Tensor + ) -> torch.Tensor: + B, H, K = mus.shape + params = torch.stack([mus, sigma_raws, logits], dim=-1) # [B, H, K, 3] + return params.reshape(B, H * K * 3) + + return _build diff --git a/tests/models/test_gmm_basemodel.py b/tests/models/test_gmm_basemodel.py new file mode 100644 index 00000000..442297d7 --- /dev/null +++ b/tests/models/test_gmm_basemodel.py @@ -0,0 +1,111 @@ +import pytest +import torch +from pvnet.models.base_model import BaseModel + + +def test_parse_gmm_params_shapes_and_constraints( + gmm_model_factory, build_y_gmm_from_params +): + torch.manual_seed(0) + + B, H, K = 4, 5, 3 + model = gmm_model_factory(forecast_len=H, num_components=K) + + mus = torch.randn(B, H, K) + sigma_raws = torch.randn(B, H, K) # will go through softplus + 1e-3 + logits = torch.randn(B, H, K) # will go through softmax + + y_gmm = build_y_gmm_from_params(mus, sigma_raws, logits) + out_mus, out_sigmas, out_pis = model._parse_gmm_params(y_gmm) + + assert out_mus.shape == (B, H, K) + assert out_sigmas.shape == (B, H, K) + assert out_pis.shape == (B, H, K) + + assert torch.all(out_sigmas > 0) + assert torch.all(out_sigmas >= 1e-3) + + sums = out_pis.sum(dim=-1) # [B, H] + assert torch.allclose(sums, torch.ones_like(sums), atol=1e-6) + assert torch.all(out_pis >= 0) + + +def test_parse_gmm_params_softplus_offset(gmm_model_factory, build_y_gmm_from_params): + """Check the exact softplus + 1e-3 transform for sigma.""" + B, H, K = 1, 2, 2 + model = gmm_model_factory(forecast_len=H, num_components=K) + + mus = torch.zeros(B, H, K) + sigma_raws = torch.tensor( + [[[-10.0, 0.0], [-5.0, 1.5]]], dtype=torch.float + ) # [1,2,2] + logits = torch.zeros(B, H, K) + + y_gmm = build_y_gmm_from_params(mus, sigma_raws, logits) + _, sigmas, _ = model._parse_gmm_params(y_gmm) + + expected = torch.nn.functional.softplus(sigma_raws) + 1e-3 + assert torch.allclose(sigmas, expected, atol=1e-7) + + +def test_gmm_to_prediction_matches_manual_expectation( + gmm_model_factory, build_y_gmm_from_params +): + """ + Build a tiny case we can compute by hand. + Use equal logits -> equal weights, and simple means. + """ + B, H, K = 1, 2, 2 + model = gmm_model_factory(forecast_len=H, num_components=K) + + mus = torch.tensor([[[2.0, 6.0], [10.0, 14.0]]]) # [1,2,2] + sigma_raws = torch.zeros(B, H, K) + logits = torch.zeros(B, H, K) # equal weights + + y_gmm = build_y_gmm_from_params(mus, sigma_raws, logits) + pred = model._gmm_to_prediction(y_gmm) # [B, H] + + expected = torch.tensor([[4.0, 12.0]]) + assert torch.allclose(pred, expected, atol=1e-6) + + +def test_sample_from_gmm_empirical_mean_near_expectation(gmm_model_factory): + """ + Sampling should yield an empirical mean near the analytic mixture mean. + Use small variances and many samples to tighten the estimate. 
+ """ + torch.manual_seed(123) + + B, H, K = 2, 3, 3 + model = gmm_model_factory(forecast_len=H, num_components=K) + + mus = torch.tensor( + [ + [[0.0, 5.0, 10.0], [1.0, 6.0, 11.0], [2.0, 7.0, 12.0]], + [[3.0, 8.0, 13.0], [4.0, 9.0, 14.0], [5.0, 10.0, 15.0]], + ] + ) + sigmas = torch.full_like(mus, 0.05) + + logits = torch.tensor([0.0, 1.0, 2.0]).repeat(B, H, 1) + pis = torch.softmax(logits, dim=-1) + + n_samples = 5000 + samples = model._sample_from_gmm(mus, sigmas, pis, n_samples=n_samples) # [S,B,H] + + empirical_mean = samples.mean(dim=0) # [B,H] + analytic_mean = (pis * mus).sum(dim=-1) + + assert torch.allclose(empirical_mean, analytic_mean, atol=0.15) + + +def test_quantiles_and_gmm_mutually_exclusive(): + """Constructor should forbid using both quantiles and GMM simultaneously.""" + with pytest.raises(ValueError): + _ = BaseModel( + history_minutes=0, + forecast_minutes=60, + output_quantiles=[0.1, 0.5, 0.9], + num_gmm_components=2, + interval_minutes=30, + )