GSOC: Probabilistic Machine Learning for Solar Forecasting: Applying Gaussian Mixture Models for output #448
Conversation
pvnet/models/base_model.py
Outdated
history_minutes: int,
forecast_minutes: int,
- output_quantiles: list[float] | None = None,
+ output_quantiles: Optional[list[float]] | None = None,
I think this can be simplified to just:
output_quantiles: Optional[list[float]] = None,
coolio, will do :)
pvnet/models/base_model.py
Outdated
pis = F.softmax(logits, dim=-1)
return mus, sigmas, pis

def _quantiles_to_prediction(self, y_quantiles):
Could you add type hints here, i.e. change to: def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
Slotted in!
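As context for readers, a minimal sketch of what the typed method might look like, assuming (this is an assumption, not the PR's code) that the point prediction is simply the middle quantile:

import torch

class _Sketch:
    def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
        """Collapse (batch, forecast_len, num_quantiles) to a point forecast."""
        # Assumes the middle quantile (e.g. 0.5 of [0.1, 0.5, 0.9]) is the point estimate
        middle = y_quantiles.shape[-1] // 2
        return y_quantiles[..., middle]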
pvnet/training/lightning_module.py
Outdated
y_gmm: (batch, forecast_len * num_components * 3)
y_true: (batch, forecast_len)
"""
mus, sigmas, pis = self._parse_gmm_params(y_gmm)
Forgive me if I am wrong - but I think this should be updated to:
self.model._parse_gmm_params(y_gmm)
The method is in base_model, right - and the Lightning module holds an instance of the model itself, i.e. self.model
Yup! resolved :)
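For reference, a minimal sketch of what such a parsing helper could look like, based on the shape given in the docstring above ((batch, forecast_len * num_components * 3)); the parameter ordering here is an assumption and may differ from the PR:

import torch
import torch.nn.functional as F

def _parse_gmm_params(y_gmm: torch.Tensor, forecast_len: int, num_components: int):
    """Split a flat (batch, forecast_len * num_components * 3) tensor into
    per-step mixture means, stds, and weights."""
    batch_size = y_gmm.shape[0]
    # One (mu, raw_sigma, logit) triplet per component and forecast step
    params = y_gmm.view(batch_size, forecast_len, num_components, 3)
    mus = params[..., 0]
    # Softplus keeps the standard deviations strictly positive
    sigmas = F.softplus(params[..., 1]) + 1e-6
    # Softmax normalises the mixture weights across components
    pis = F.softmax(params[..., 2], dim=-1)
    return mus, sigmas, pis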
pvnet/training/lightning_module.py
Outdated
if self.model.use_quantile_regression:
-   losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
+   y_hat = self.model._quantiles_to_prediction(y_hat)
+   losses["quantile_loss"] = self.model._calculate_quantile_loss(y_hat, y)
In line with line 91, a potential update:
self._calculate_quantile_loss(y_hat, y)
sorry! missed this in some of the refactoring, changed it now :)
Think it's fine as is - though perhaps a little better with the reversion, I feel
if self.use_quantile_regression:
    # Shape: batch_size, seq_length * num_quantiles
-   out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
+   out = out.view(out.size(0), self.forecast_len, len(self.output_quantiles))
Solid refinement!
:)
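For context on the change itself: view() returns a tensor sharing the same storage and raises an error if the input isn't contiguous, while reshape() silently falls back to a copy; on a contiguous head output the two are equivalent, so view() just makes the no-copy intent explicit.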
tests/models/test_gmm_basemodel.py
Outdated
from pvnet.models.base_model import BaseModel


class _MinimalGMMModel(BaseModel):
Brilliant - thanks for adding these in! Would you mind moving _MinimalGMMModel and potentially _build_y_gmm_from_params to conftest if all OK?
good shout :) moved :)
Thanks!
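A rough sketch of how the moved helper might sit in tests/conftest.py (the constructor kwargs here are hypothetical placeholders, not the PR's actual values):

import pytest
from pvnet.models.base_model import BaseModel


class _MinimalGMMModel(BaseModel):
    """Smallest BaseModel subclass needed to exercise the GMM helpers."""


@pytest.fixture()
def minimal_gmm_model() -> BaseModel:
    # Hypothetical kwargs - the real fixture mirrors the tests' setup
    return _MinimalGMMModel(
        history_minutes=60,
        forecast_minutes=120,
        num_gmm_components=2,
    )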
batch: TensorBatch,
y_hat: torch.Tensor,
quantiles: list[float] | None,
model,
Great compatibility shift - thanks!
Not sure whether you feel it would be a strong addition, but what about also showing actual example samples from the mixture alongside, say, the mean?
the plotting currently also shows the std distribution error per step - is this what you mean?
Ah yeah, cheers, sorry. Maybe a few random samples from the predicted GMM distribution at each step, shown individually alongside the mean - completely up to you though
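For what it's worth, the suggested sampling could be sketched with torch.distributions along these lines (mus/sigmas/pis shaped (batch, forecast_len, num_components), as elsewhere in this thread; illustrative only, not the PR's code):

import torch

def sample_gmm(mus, sigmas, pis, n_samples: int = 5) -> torch.Tensor:
    """Draw samples from the per-step Gaussian mixture for plotting.

    Returns a tensor of shape (n_samples, batch, forecast_len).
    """
    mixture = torch.distributions.Categorical(probs=pis)
    components = torch.distributions.Normal(mus, sigmas)
    gmm = torch.distributions.MixtureSameFamily(mixture, components)
    return gmm.sample((n_samples,))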
Hi @murphytarra, thank you for all of your work on this. It looks really great already. I'll be doing an extra review on this PR as well, alongside Felix. Thought I'd drop in to say hello first, and from what I can see there are just a few places it could be tidied up a little bit.
Looks really good @murphytarra! I really like how you've integrated it.
My comments are mostly around tidy-ups. I presume you've used a linter with the line length limit set to 80 rather than 100. I'd prefer those line breaks to be undone where they have reduced readability.
Hi @murphytarra, thanks for the changes! I see the tests are currently failing, but I think that's due to the github workflows now being out of date. If you merge the updated main branch into this branch it should fix the tests. I'm going to go through and tag a few more lines which have been split where I think the split reduces readability. I'd appreciate it if you could revert them.
history_minutes: int,
forecast_minutes: int,
- output_quantiles: list[float] | None = None,
+ output_quantiles: Optional[list[float]] = None,
Could you change the type hint back here, and match in the line below?
param: type | None = None
is the current best practice for Python >= 3.10, rather than param: Optional[type] = None,
which was the practice for Python < 3.10.
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
num_gmm_components: Number of Gaussian Mixture Model components to use for the model.
If None, output quantiles must be set. If both None, the output is a single value.
gsp_ids = np.arange(0, 318)
capacity = np.ones((len(times), len(gsp_ids)))
- generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(np.float32)
+ generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(
Please unsplit
) -> str:

# Populate the config with the generated zarr paths
config = load_yaml_configuration(f"{_top_test_directory}/test_data/uk_data_config.yaml")
Please unsplit
) -> str:

# Populate the config with the generated zarr paths
config = load_yaml_configuration(f"{_top_test_directory}/test_data/site_data_config.yaml")
Please unsplit
@pytest.fixture()
def late_fusion_model_kwargs_site_history(raw_late_fusion_model_kwargs_site_history) -> dict:
Please unsplit
@pytest.fixture()
def late_fusion_model_site_history(late_fusion_model_kwargs_site_history) -> LateFusionModel:
Please unsplit
Switchable Probabilistic Output: Quantile Regression ⇄ Gaussian Mixture Model (GMM)
Summary
This PR adds a configurable probabilistic head so users can choose between the existing quantile regression output and a new Gaussian Mixture Model (GMM) output.
When GMM is selected, the model predicts per-component weights, means, and standard deviations instead of fixed quantiles.
Files changed: base_model.py, lightning_module.py
What’s changed
- New gmm output head: for each of K mixture components, the model predicts mixture weights (softmax-normalized), means, and stds (softplus-constrained).
- The existing quantile head is unchanged.
API / Configuration
Minimal, explicit configuration: use num_gmm_components instead of "num_quantiles" - if both are given, an error is flagged.
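For illustration, the mutually exclusive options might be passed like this (the kwarg names are taken from the diffs above; the values are placeholders):

model_kwargs = dict(
    history_minutes=60,
    forecast_minutes=480,
    num_gmm_components=3,  # enables the GMM head
    # output_quantiles=[0.1, 0.5, 0.9],  # quantile head instead; setting both raises an error
)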
Output tensors
Quantile mode (unchanged): [B, T, Q] (batch, horizon, num_quantiles)
GMM mode:
- weights: [B, T, K] (softmax over K)
- means: [B, T, K]
- stds: [B, T, K] (softplus > 0)
Backward compatibility
Training / Loss
The GMM head is trained with the negative log-likelihood of the K-component diagonal Gaussian mixture.
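A stability-minded sketch of such a mixture NLL (shapes follow the docstring earlier in the thread; this is illustrative, not the PR's exact code):

import torch

def gmm_nll(mus, sigmas, pis, y_true) -> torch.Tensor:
    """Negative log-likelihood of y_true under a diagonal Gaussian mixture.

    mus, sigmas, pis: (batch, forecast_len, num_components)
    y_true: (batch, forecast_len)
    """
    components = torch.distributions.Normal(mus, sigmas)
    # log N(y | mu_k, sigma_k) per component; unsqueeze y to broadcast over components
    log_probs = components.log_prob(y_true.unsqueeze(-1))
    # log sum_k pi_k N(y | ...) computed stably via logsumexp
    log_mixture = torch.logsumexp(torch.log(pis + 1e-12) + log_probs, dim=-1)
    return -log_mixture.mean()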
Evaluation & Inference Notes
For deterministic point forecasts in GMM mode, use the mixture expected value: E[y_t] = Σ_k π_{t,k} μ_{t,k}.
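In code, with the shapes above, that is a one-liner:

# Mixture expected value per step: (batch, forecast_len, K) -> (batch, forecast_len)
y_point = (pis * mus).sum(dim=-1)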
Migration guide
- Staying with quantile: no action required.
- Switching to gmm: set num_gmm_components: K in the config file and remove num_quantiles.
Checklist: