Feature/ Foundation model and Chronos-2 #2944
Co-authored-by: Zhihao Dai <[email protected]>
- Docstring for `FoundationModel`. - Docstring for `HuggingFaceModelMixin`. - Add `probabilistic` parameter for converting probabilistic TSFMs into deterministic ones (might not be supported by all TSFMs). Co-authored-by: Zhihao Dai <[email protected]>
- Docstring for `Chronos2Model`. - Docstring for `_Chronos2Module`. - Add `probabilistic` parameter to convert Chronos2 into a deterministic model by taking the median quantile. Co-authored-by: Zhihao Dai <[email protected]>
Codecov Report
@@ Coverage Diff @@
## master #2944 +/- ##
==========================================
- Coverage 95.52% 95.48% -0.05%
==========================================
Files 146 150 +4
Lines 15710 16198 +488
==========================================
+ Hits 15007 15466 +459
- Misses 703 732 +29
@dennisbader Could you please add @abdulfatir as a reviewer? The PR is not fully ready yet due to missing parts and discussion points (see above), but we can ask @abdulfatir for review once those are completed.
@daidahao, is @abdulfatir not able to review already? I believe every user should be able to add a review by default.
Aha, my mistake. I thought we would need to add someone explicitly to review the changes.
abdulfatir
left a comment
@daidahao This looks great! Thank you for your effort. I mainly reviewed the Chronos-2 part and only have minor comments on that. Regarding the design and Darts-specific stuff, the maintainers may provide better feedback.
    layer_norm_epsilon: float = 1e-6,
    feed_forward_proj: str = "relu",
    rope_theta: float = 10000.0,
    attn_implementation: Literal["eager", "sdpa"] | None = None,
I saw discussion somewhere on whether SDPA would require changes to torch versions. Note that the benefit that SDPA provides for Chronos-2 is relatively minor, so if needed the default may be changed to eager. See: amazon-science/chronos-forecasting#331
@abdulfatir Great insight from amazon-science/chronos-forecasting#331, which I have followed since it was first posted. Regarding the torch version, I am in favour of raising torch to >=2.0.0, since it was released more than two years ago and could bring performance benefits for Darts users, including sdpa, torch.compile(), etc. I will leave it as is for now and wait to hear what @dennisbader has to say.
Yes, you can go ahead and raise torch to >=2.0.0 👍
Done in 32ae5c0.
Now that we are raising torch to >=2.0.0, would it be okay to remove the `attn_implementation` option (and any eager implementation), as the user would not set it from `Chronos2Model` and it always defaults to sdpa anyway? @abdulfatir
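For context on the alternative raised in this thread (falling back to `eager` on older torch instead of raising the minimum version), here is a minimal sketch. The helper `pick_attn_implementation` is hypothetical and not part of this PR; it only parses a version string and assumes the major version is the first dot-separated component.

```python
def pick_attn_implementation(torch_version: str) -> str:
    """Return "sdpa" for torch >= 2.0.0, otherwise "eager" (hypothetical helper)."""
    # Keep only the leading major component, e.g. "2.1.0+cu118" -> 2.
    major = int(torch_version.split(".")[0])
    return "sdpa" if major >= 2 else "eager"


print(pick_attn_implementation("1.13.1"))       # eager
print(pick_attn_implementation("2.1.0+cu118"))  # sdpa
```

With torch>=2.0.0 as the minimum, a switch like this is no longer needed, which is what motivates the question about dropping the option.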
dennisbader
left a comment
Really great stuff @daidahao, thanks a lot 🚀 It's really nice to see how nicely it can be integrated. And I think you laid a very strong foundation for future work 🌟
I added a couple of suggestions. Let me know what you think.
@daidahao notebook looks great! A small comment: the example only shows past covariates, and for this specific split, using past-only covariates actually worsens the results slightly. This will be very confusing for the reader. I would recommend showing the known-future covariates scenario in the notebook instead, as that is generally the more popular setting. Of course, you should add a caveat about which things can be known in the future, or reasonably predicted (e.g. weather), and so supplied as future covariates.
@abdulfatir Thank you for the suggestions. I noticed the worse results too when creating the notebook. Unfortunately, there are no future covariates available in this particular dataset, as the weather measurements came from a weather station and not a forecast. Creating future covariates like time of day is possible but may not be as interesting as using covariates from the data source. For the sake of demonstration, I will update the notebook to use weather measurements as future covariates instead, and add a caveat that one should use weather forecasts in practice. What do you think?
Yes, this makes sense and is also a common approach in practice.
dennisbader
left a comment
Really nice updates, thanks a lot @daidahao
The example notebook looks great as well!
I added some last few minor suggestions, after that everything should be good to go 🚀
Also predict quantiles directly instead of sampling for probabilistic forecasts. Co-authored-by: Zhihao Dai <[email protected]>
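The quantile handling named in this commit and in the PR summary (median only by default, user-selected quantiles otherwise) can be sketched without any library. This is an illustration under stated assumptions, not the Darts implementation: `PRETRAINED_QUANTILES` and `select_quantiles` are hypothetical names, and the quantile levels are made up rather than taken from the Chronos-2 checkpoint.

```python
from typing import List, Optional, Sequence

# Hypothetical pre-trained quantile levels; the real checkpoint defines its own set.
PRETRAINED_QUANTILES = [0.1, 0.25, 0.5, 0.75, 0.9]


def select_quantiles(
    forecast: Sequence[Sequence[float]],
    requested: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Pick columns of a (time, quantile) forecast.

    With `requested=None`, only the median (0.5) column is returned,
    which mirrors a deterministic forecast.
    """
    levels = [0.5] if requested is None else list(requested)
    # Only levels the model was trained on can be read off directly.
    missing = [q for q in levels if q not in PRETRAINED_QUANTILES]
    if missing:
        raise ValueError(f"quantiles {missing} are not among the pre-trained levels")
    idx = [PRETRAINED_QUANTILES.index(q) for q in levels]
    return [[row[i] for i in idx] for row in forecast]


# Two time steps, one value per pre-trained quantile level.
forecast = [
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [2.0, 3.0, 4.0, 5.0, 6.0],
]
print(select_quantiles(forecast))              # [[3.0], [4.0]]
print(select_quantiles(forecast, [0.1, 0.9]))  # [[1.0, 5.0], [2.0, 6.0]]
```

Reading quantiles straight from the model output avoids the sampling step, so repeated calls give identical results.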
@dennisbader Can I suggest ignoring files like `CHANGELOG.md` in the PR workflow:
    name: darts PR workflow
    on:
      pull_request:
        branches:
          - master
        paths-ignore:
          - 'CHANGELOG.md'
Yes that sounds like a good idea :) Could we do this in a separate PR?
Beautiful 😍 Thanks again for this great work, and also impressive how quickly you could set this up! 🚀 This will allow us to add new foundation models efficiently! Really nice to see this, kudos 🔥
Indeed, great work @daidahao!
@abdulfatir @dennisbader Thank you both for your code reviews. Learnt a lot from this experience 😊
Yes, thank you too @abdulfatir for the helpful reviews!
Checklist before merging this PR:
Fixes #2943.
Summary
This PR adds base `FoundationModel` and `Chronos2Model` to Darts forecasting models.
Major Changes
- `FoundationModel` as a base class for all foundation forecasting models. All foundation models like Chronos-2 should inherit the base to make use of PyTorch datasets, optimized historical forecasting, Lightning APIs for model training (fine-tuning), checkpointing, etc.
- `HuggingFaceModelMixin` as a mixin class for foundation models that require downloading model configuration and weights from HuggingFace Hub. The class provides methods for downloading model config, weight files, and loading them into a `PLForecastingModule` instance.
- `Chronos2Model` for zero-shot forecasting using Amazon's pre-trained checkpoint. It supports past and future covariates and can convert `quantiles` from Chronos-2 to the Darts `QuantileRegression` likelihood model. By default, it is deterministic, outputting only the median quantile. It can be probabilistic with user-selected quantiles from the pre-trained quantiles. Fine-tuning is not supported for now.
- `huggingface-hub` and `safetensors`, the former for downloading model files and the latter for loading model weights. Both should be lightweight enough to ship with Darts.
- `torch` raised to `>2.0.0` for SDPA and other performance benefits.
Quickstart
Chronos-2 Adaptation
Chronos-2 was ported from amazon-science/chronos-forecasting@c23d34c and I have since made some changes to integrate it within Darts:
- Remove dependencies on the `transformers` and `einops` libraries.
- […] `HuggingFaceModelMixin`.
- Remove the `output_attentions` option from the forward pass.
- […] `QuantileRegression`, and remove the original loss computation in the forward pass.
- Replace the `*Output` return type with a direct `torch.Tensor` to comply with the Darts `PLForecastingModule` interface.
- Replace `einops` rearrange operations with native PyTorch tensor operations.
The key principle here is to introduce as few dependencies as possible and implement it fully in PyTorch. I think using the `chronos-forecasting` library is convenient but risks conflicts with Darts dependencies in the future.
Fidelity Tests
For validation, I use the `ElectricityConsumptionZurichDataset` to generate forecasts with the Darts implementation and the original. See `test_chronos2.py` for details. Because TSFMs might be slower than other torch models, I limit the fidelity tests to 2 (probabilistic or median).
Important Notes: Due to differences in probabilistic sampling methods, zero-shot forecasts obtained here would differ from those obtained using the original implementation when the prediction horizon `n` is larger than 1024.
What is Missing
- […] `FoundationModel` (using `Chronos2Model` as prime example) either as part of Quickstart or User Guide.
- Example notebook for `Chronos2Model` to showcase Chronos-2 capabilities.
- CHANGELOG.
- README model catalogue update.
- Unit tests for `Chronos2Model`, maybe using a tiny mock Chronos-2 configuration.
- Unit tests for `FoundationModel`.
- Example code in `Chronos2Model` docstring.
- […] `Chronos2Model` using different datasets?
Discussions
There are a few points I would like to discuss here, following the discussions in #2933:
- `FoundationModel` inherits from `MixedCovariatesTorchModel` and thus `TorchForecastingModel`. This is to allow optimized historical forecasts, PyTorch data loaders, and Lightning APIs to be used. Would that be confusing since it also introduces a lot more parameters from `TorchForecastingModel` and `PLForecastingModule`? -> Using `TorchForecastingModel` as base for `FoundationModel`.
- Chronos-2 has its own `RINorm` routine, different from the Darts `RINorm` in `io_processor()`. See my notes in `_Chronos2Module.forward()` for details. What would be the best way to integrate both and to ensure fine-tuning support w/ normalized loss? -> Not integrated for now until fine-tuning is supported.
- I understand fine-tuning is not a priority for now. But since Chronos-2 can be trained using the quantile loss, should we allow fine-tuning Chronos-2 for our users? -> Fine-tuning not supported for now.
- SDPA is used by default by Chronos-2 but was only introduced in PyTorch 2.0.0. Should we raise `torch` to 2.0.0 or differentiate between using `torch<2.0.0` and `torch>2.0.0`? -> Raised to `>2.0.0`.
Other Information