Conversation

murphytarra

Switchable Probabilistic Output: Quantile Regression ⇄ Gaussian Mixture Model (GMM)

Summary

This PR adds a configurable probabilistic head so users can choose between the existing quantile regression output and a new Gaussian Mixture Model (GMM) output.
When GMM is selected, the model predicts per-component weights, means, and standard deviations instead of fixed quantiles.

  • Main files touched: base_model.py, lightning_module.py

What’s changed

  • New output mode: gmm
    • Predicts K mixture weights (softmax-normalized), means, and stds (softplus-constrained).
    • Trains with negative log-likelihood of the Gaussian mixture.
  • Existing output mode: quantile
    • Unchanged behaviour; trains with pinball loss.
  • Inference utilities for GMM:
    • Compute the mixture mean/variance from the predicted parameters (used in each epoch update).
  • Tests added (finishing touches pending; to be pushed shortly).

API / Configuration

Minimal, explicit configuration:

# model config
output_distribution: "gmm"       # "quantile" (default) or "gmm"
num_gmm_components: 3            # only used if output_distribution == "gmm"

Set num_gmm_components instead of num_quantiles; if both are given, a configuration error is raised.
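The mutual-exclusion rule above could be enforced with a small guard like this (a hypothetical helper sketch; the PR's actual validation code is not shown here):

```python
def validate_output_config(output_quantiles, num_gmm_components):
    """Hypothetical config guard: quantile and GMM modes are mutually
    exclusive, so setting both options is a configuration error."""
    if output_quantiles is not None and num_gmm_components is not None:
        raise ValueError(
            "Set either output_quantiles or num_gmm_components, not both."
        )
```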


Output tensors

Quantile mode (unchanged):

  • Shape: [B, T, Q] (batch, horizon, num_quantiles)

GMM mode:

  • Weights: [B, T, K] (softmax over K)
  • Means: [B, T, K]
  • Stds: [B, T, K] (softplus > 0)

Internally we apply softmax to weights and softplus to stds to ensure valid parameters.
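As a rough illustration of these constraints, here is a NumPy sketch that splits a flat head output into valid mixture parameters. The flat layout and the mu/sigma/pi ordering are assumptions for illustration; the PR itself operates on torch tensors with F.softmax and F.softplus.

```python
import numpy as np

def parse_gmm_params(y_gmm: np.ndarray, forecast_len: int, K: int):
    """Split a flat head output [B, forecast_len * K * 3] into constrained
    mixture parameters. Illustrative NumPy sketch; ordering is assumed."""
    B = y_gmm.shape[0]
    raw = y_gmm.reshape(B, forecast_len, K, 3)
    mu_raw, sigma_raw, pi_raw = raw[..., 0], raw[..., 1], raw[..., 2]
    sigmas = np.log1p(np.exp(sigma_raw))  # softplus -> sigma > 0
    # numerically stable softmax over the K mixture weights
    pi_shift = pi_raw - pi_raw.max(axis=-1, keepdims=True)
    pis = np.exp(pi_shift) / np.exp(pi_shift).sum(axis=-1, keepdims=True)
    return mu_raw, sigmas, pis
```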


Backward compatibility

  • Default keeps quantile mode, so existing training/eval pipelines continue to work untouched.
  • If you opt into gmm, you will get different output heads and a different training loss.

Training / Loss

  • Quantile: pinball loss over requested quantiles (unchanged).
  • GMM: exact negative log-likelihood of a K-component diagonal Gaussian mixture.
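The GMM loss can be sketched as follows, using log-sum-exp for numerical stability (an illustrative NumPy sketch only; the PR's implementation works on torch tensors):

```python
import numpy as np

def gmm_nll(y: np.ndarray, mus, sigmas, pis) -> float:
    """Mean negative log-likelihood of targets y under a K-component
    Gaussian mixture. Shapes: y [B, T]; mus/sigmas/pis [B, T, K]."""
    y = y[..., None]  # [B, T, 1] broadcasts over the K components
    log_norm = -0.5 * np.log(2 * np.pi) - np.log(sigmas)
    log_comp = log_norm - 0.5 * ((y - mus) / sigmas) ** 2  # log N(y | mu_k, sigma_k)
    log_mix = log_comp + np.log(pis)
    m = log_mix.max(axis=-1, keepdims=True)  # log-sum-exp trick
    ll = m.squeeze(-1) + np.log(np.exp(log_mix - m).sum(axis=-1))
    return float(-ll.mean())
```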

Evaluation & Inference Notes

  • For deterministic point forecasts in GMM mode, use mixture expected value:

    $$\hat{y} = \sum_{k=1}^{K} \pi_k \mu_k$$
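The mixture mean above, together with a predictive variance from the law of total variance, can be computed like this (NumPy sketch; function and argument names are illustrative, not the PR's API):

```python
import numpy as np

def gmm_point_and_var(mus, sigmas, pis):
    """Point forecast and predictive variance from mixture parameters
    (all shaped [B, T, K]): mean = sum_k pi_k mu_k, and
    var = E[Y^2] - (E[Y])^2 via the law of total variance."""
    mean = (pis * mus).sum(axis=-1)
    second_moment = (pis * (sigmas ** 2 + mus ** 2)).sum(axis=-1)
    var = second_moment - mean ** 2
    return mean, var
```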


Migration guide

Staying with quantile: no action required.

Switching to gmm:

  1. Set num_gmm_components: K in your config file.
  2. Remove num_quantiles from your config.

Checklist:

  • [x] My code follows OCF's coding style guidelines
  • [x] I have performed a self-review of my own code
  • [ ] I have made corresponding changes to the documentation
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have checked my code and corrected any misspellings

@murphytarra murphytarra changed the title GOSC: Probabilistic Machine Learning for Solar Forecasting: Applying Gaussian Mixture Models for output GSOC: Probabilistic Machine Learning for Solar Forecasting: Applying Gaussian Mixture Models for output Aug 25, 2025
history_minutes: int,
forecast_minutes: int,
output_quantiles: list[float] | None = None,
output_quantiles: Optional[list[float]] | None = None,
Contributor:

I think this can be simplified to just:

output_quantiles: Optional[list[float]] = None,

Author:

coolio, will do :)

pis = F.softmax(logits, dim=-1)
return mus, sigmas, pis

def _quantiles_to_prediction(self, y_quantiles):
Contributor:

To: def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:

Author:

Slotted in!

y_gmm: (batch, forecast_len * num_components * 3)
y_true: (batch, forecast_len)
"""
mus, sigmas, pis = self._parse_gmm_params(y_gmm)
Contributor:

Forgive me if I am wrong - but I think this should be updated to:

self.model._parse_gmm_params(y_gmm)

Method is in base_model right - and Lightning should hold an instance of the model itself i.e. self.model

Author:

Yup! resolved :)

if self.model.use_quantile_regression:
losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
y_hat = self.model._quantiles_to_prediction(y_hat)
losses["quantile_loss"] = self.model._calculate_quantile_loss(y_hat, y)
Contributor:

In line with line 91 potential update:

self._calculate_quantile_loss(y_hat, y)

Author:

sorry! missed this in some of the refactoring, changed it now :)

Contributor:

Think it's fine as is - though perhaps better with the reversion, I feel

if self.use_quantile_regression:
# Shape: batch_size, seq_length * num_quantiles
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
out = out.view(out.size(0), self.forecast_len, len(self.output_quantiles))
Contributor:

Solid refinement!

Author:

:)

from pvnet.models.base_model import BaseModel


class _MinimalGMMModel(BaseModel):
Contributor:

Brilliant - thanks for adding these in! Would you mind moving _MinimalGMMModel and potentially _build_y_gmm_from_params to conftest if all OK?

Author:

good shout :) moved :)

Contributor:

Thanks!

batch: TensorBatch,
y_hat: torch.Tensor,
quantiles: list[float] | None,
model,
Contributor:

Great compatibility shift - thanks!

Not sure whether or not you feel it could be a strong addition, but what about an extra aspect that includes showing actual example samples from the mixture alongside say the mean?

Author:

the plotting currently also plots the std distribution error per step as well - is this what you mean?

Contributor:

Ah yeah, cheers sorry. Maybe like a few random samples from the predicted GMM distribution at each step and have those individually alongside the mean - completely up to you though

@felix-e-h-p felix-e-h-p requested a review from dfulu September 1, 2025 10:12

dfulu commented Sep 1, 2025

Hi @murphytarra, thank you for all of your work on this. It looks really great already. I'll be doing an extra review on this PR as well as Felix.

Thought I'd drop in to say hello first, and from what I can see there are just a few places it could be tidied up a little bit

@dfulu (Member) left a comment:

Looks really good @murphytarra! I really like how you've integrated it

My comments are mostly around tidy-ups. I presume you've used a linter which has the line length limit set to 80 rather than 100. I'd prefer those line breaks to be undone where they have reduced readability.

@murphytarra murphytarra requested a review from dfulu September 1, 2025 20:49

dfulu commented Sep 2, 2025

Hi @murphytarra thanks for the changes!

I see the tests are currently failing but I think that's due to the github workflows now being out of date. If you merge the updated main branch into this branch it should fix the tests.

I'm going to go through and tag a few more lines which have been split where I think the split reduces readability. I'd appreciate it if you could revert them

history_minutes: int,
forecast_minutes: int,
output_quantiles: list[float] | None = None,
output_quantiles: Optional[list[float]] = None,
Member:

Could you change the type hint back here and match in the line below?

param: type | None = None is the current best practice for python >= 3.10, rather than param: Optional[type] = None which was the practice for python < 3.10

output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
num_gmm_components: Number of Gaussian Mixture Model components to use for the model.
If None, output quantiles must be set. If both None, the output is a single value.
Member:

Suggested change
If None, output quantiles must be set. If both None, the output is a single value.
If None, output quantiles must be set. If both None, the output is a single value.

gsp_ids = np.arange(0, 318)
capacity = np.ones((len(times), len(gsp_ids)))
generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(np.float32)
generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(
Member:

Please unsplit

) -> str:

# Populate the config with the generated zarr paths
config = load_yaml_configuration(f"{_top_test_directory}/test_data/uk_data_config.yaml")
Member:

Please unsplit

) -> str:

# Populate the config with the generated zarr paths
config = load_yaml_configuration(f"{_top_test_directory}/test_data/site_data_config.yaml")
Member:

Please unsplit



@pytest.fixture()
def late_fusion_model_kwargs_site_history(raw_late_fusion_model_kwargs_site_history) -> dict:
Member:

Please unsplit



@pytest.fixture()
def late_fusion_model_site_history(late_fusion_model_kwargs_site_history) -> LateFusionModel:
Member:

Please unsplit
