
Commit fc329c8

Hyperparameter tuning with Hyperopt.jl (#109)
* hyperparameter tuning, passing hyperparameters as kwargs
* md and script with hyperparameter tuning
* function to extract and pass best hyperparameters to another train (via tune)
* re-added option to pass data as in a tuple
1 parent 827834b commit fc329c8

File tree

10 files changed: +394, -15 lines changed


Project.toml

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
@@ -47,6 +48,7 @@ DataFrames = "1"
 Downloads = "1.6.0"
 Flux = "0.16"
 ForwardDiff = "1.0.1"
+Hyperopt = "0.5.6"
 JLD2 = "0.5.13, 0.6"
 Lux = "1.12.4"
 LuxCore = "1.2.4"

docs/Project.toml

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
+Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"

 [sources]
 EasyHybrid = {path = ".."}

docs/make.jl

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ makedocs(;
         "Get Started" => "get_started.md",
         "Tutorial" => [
             "Exponential Response" => "tutorials/exponential_res.md",
+            "Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md"
         ],
         "Research" => [
             "Overview" => "research/overview.md"

docs/src/tutorials/hyperparameter_tuning.md

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@

```@raw html
---
authors:
  - name: Bernhard Ahrens
    avatar: https://raw.githubusercontent.com/EarthyScience/EasyHybrid.jl/72c2fa9df829d46d25df15352a4b728d2dbe94ed/docs/src/assets/Bernhard_Ahrens.png
    link: https://www.bgc-jena.mpg.de/en/bgi/miss
  - name: Lazaro Alonso
    avatar: https://avatars.githubusercontent.com/u/19525261?v=4
    platform: github
    link: https://lazarusa.github.io
---

<Authors />
```

# Hyperparameter Tuning

### 1. Setup and Data Loading

Load package and synthetic dataset

```@example hyperparameter_tuning
using EasyHybrid
using CairoMakie
using Hyperopt
```

```@example hyperparameter_tuning
ds = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc")
ds = ds[1:20000, :] # Use subset for faster execution
first(ds, 5)
```

### 2. Define the Process-based Model

RbQ10 model: Respiration model with Q10 temperature sensitivity

```@example hyperparameter_tuning
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
```
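
In equation form, the code above implements the temperature response

```math
R_\mathrm{eco} = R_b \cdot Q_{10}^{\,(T_a - T_\mathrm{ref})/10}, \qquad T_\mathrm{ref} = 15\,^\circ\mathrm{C}
```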

### 3. Configure Model Parameters

Parameter specification: (default, lower_bound, upper_bound)

```@example hyperparameter_tuning
parameters = (
    rb  = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s]
    Q10 = (2.0f0, 1.0f0, 4.0f0),  # Temperature sensitivity - factor by which respiration increases for a 10 K rise in temperature [-]
)
```

### 4. Construct the Hybrid Model

Define input variables

```@example hyperparameter_tuning
forcing = [:ta]                  # Forcing variables (temperature)
predictors = [:sw_pot, :dsw_pot] # Predictor variables (solar radiation)
target = [:reco]                 # Target variable (respiration)
```

Classify parameters as global, neural, or fixed: global parameters take a single value shared across all samples, whereas neural parameters are predicted by the neural network from the predictors.

```@example hyperparameter_tuning
global_param_names = [:Q10] # Global parameters (same for all samples)
neural_param_names = [:rb]  # Neural network predicted parameters
```

Construct hybrid model

```@example hyperparameter_tuning
hybrid_model = constructHybridModel(
    predictors,               # Input features
    forcing,                  # Forcing variables
    target,                   # Target variables
    RbQ10,                    # Process-based model function
    parameters,               # Parameter definitions
    neural_param_names,       # NN-predicted parameters
    global_param_names,       # Global parameters
    hidden_layers = [16, 16], # Neural network architecture
    activation = relu,        # Activation function
    scale_nn_outputs = true,  # Scale neural network outputs
    input_batchnorm = false   # Apply batch normalization to inputs
)
```

### 5. Train the Model

```@example hyperparameter_tuning
out = train(
    hybrid_model,
    ds,
    ();
    nepochs = 100,               # Number of training epochs
    batchsize = 512,             # Batch size for training
    opt = AdamW(0.001),          # Optimizer and learning rate
    monitor_names = [:rb, :Q10], # Parameters to monitor during training
    yscale = identity,           # Scaling for outputs
    patience = 30,               # Early stopping patience
    show_progress = false,
    hybrid_name = "before"
)
```

```@raw html
<video src="../training_history_before.mp4" controls="controls" autoplay="autoplay"></video>
```

### 6. Check Results

Evolution of train and validation loss

```@example hyperparameter_tuning
EasyHybrid.plot_loss(out, yscale = identity)
```

Check the learned Q10 - what do you think: is it the true Q10 used to generate the synthetic dataset?

```@example hyperparameter_tuning
out.train_diffs.Q10
```

Quick scatterplot - `poplot` dispatches on the output of `train`

```@example hyperparameter_tuning
EasyHybrid.poplot(out)
```

## Hyperparameter Tuning

EasyHybrid provides built-in hyperparameter tuning capabilities to optimize your model configuration. This is especially useful for finding the best neural network architecture, optimizer settings, and other hyperparameters.

### Basic Hyperparameter Tuning

You can use the `tune` function to automatically search for optimal hyperparameters:

```@example hyperparameter_tuning
# Create empty model specification for tuning
mspempty = ModelSpec()

# Define hyperparameter search space
nhyper = 4
ho = @thyperopt for i = nhyper,
        opt = [AdamW(0.01), AdamW(0.1), RMSProp(0.001), RMSProp(0.01)],
        input_batchnorm = [true, false]

    hyper_parameters = (; opt, input_batchnorm)
    println("Hyperparameter run: ", i, " of ", nhyper, " with hyperparameters: ", hyper_parameters)

    # Run tuning with current hyperparameters
    out = EasyHybrid.tune(
        hybrid_model,
        ds,
        mspempty;
        hyper_parameters...,
        nepochs = 10,
        plotting = false,
        show_progress = false,
        file_name = "test$i.jld2"
    )

    out.best_loss
end

# Get the best hyperparameters
ho.minimizer
printmin(ho)

# Extract the best hyperparameters to retrain with them below
best_hyperp = best_hyperparams(ho)
```

### Train the Model with the Best Hyperparameters

```@example hyperparameter_tuning
# Run tuning with the best hyperparameters found above
out_tuned = EasyHybrid.tune(
    hybrid_model,
    ds,
    mspempty;
    best_hyperp...,
    nepochs = 100,
    monitor_names = [:rb, :Q10],
    hybrid_name = "after"
)

# Check the tuned model performance
out_tuned.best_loss
```

```@raw html
<video src="../training_history_after.mp4" controls="controls" autoplay="autoplay"></video>
```

### Key Hyperparameters to Tune

When tuning your hybrid model, consider these important hyperparameters (a sketch of an extended search space follows this list):

- **Optimizer and Learning Rate**: Try different optimizers (AdamW, RMSProp, Adam) with various learning rates
- **Neural Network Architecture**: Experiment with different `hidden_layers` configurations
- **Activation Functions**: Test different activation functions (relu, sigmoid, tanh)
- **Batch Normalization**: Enable/disable `input_batchnorm` and other normalization options
- **Batch Size**: Adjust `batchsize` for optimal training performance
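
The sketch below extends the `@thyperopt` search from the basic example to cover these settings. It is not executed here, and it assumes that `tune` forwards `hidden_layers`, `activation`, and `batchsize` as keyword arguments in the same way it forwards `opt` and `input_batchnorm`; the trial file names are made up for illustration.

```julia
# Sketch of a broader search space (not run in this tutorial).
# Assumption: `tune` accepts `hidden_layers`, `activation`, and `batchsize`
# as keyword arguments, just like `opt` and `input_batchnorm` above.
nhyper_ext = 16
ho_ext = @thyperopt for i = nhyper_ext,
        opt = [AdamW(0.001), AdamW(0.01), RMSProp(0.001)],
        hidden_layers = [[16, 16], [32, 32], [64, 32]],
        activation = [relu, sigmoid, tanh],
        input_batchnorm = [true, false],
        batchsize = [256, 512]

    hyper_parameters = (; opt, hidden_layers, activation, input_batchnorm, batchsize)
    out = EasyHybrid.tune(
        hybrid_model,
        ds,
        ModelSpec();
        hyper_parameters...,
        nepochs = 10,
        plotting = false,
        show_progress = false,
        file_name = "extended_test$i.jld2"  # hypothetical file name
    )
    out.best_loss  # value Hyperopt minimizes
end

printmin(ho_ext)  # report the best combination found
```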

### Tips for Hyperparameter Tuning

- **Start with a small search space** to get a baseline understanding
- **Monitor for overfitting** by tracking validation loss
- **Consider computational cost** - more hyperparameters and epochs increase training time

## More Examples

Check out the `projects/` directory for additional examples and use cases. Each project demonstrates different aspects of hybrid modeling with EasyHybrid.

projects/book_chapter/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,6 @@
 [deps]
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
+Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
 NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
+Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
 WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"

projects/book_chapter/example_synthetic.jl

Lines changed: 86 additions & 6 deletions

@@ -88,9 +88,9 @@ hybrid_model = constructHybridModel(
     neural_param_names,       # NN-predicted parameters
     global_param_names,       # Global parameters
     hidden_layers = [16, 16], # Neural network architecture
-    activation = swish,       # Activation function
+    activation = sigmoid,     # Activation function
     scale_nn_outputs = true,  # Scale neural network outputs
-    input_batchnorm = true    # Apply batch normalization to inputs
+    input_batchnorm = false   # Apply batch normalization to inputs
 )

 # =============================================================================
@@ -105,15 +105,95 @@ out = train(
     ();
     nepochs = 100,                # Number of training epochs
     batchsize = 512,              # Batch size for training
-    opt = RMSProp(0.001),         # Optimizer and learning rate
+    opt = AdamW(0.1),             # Optimizer and learning rate
     monitor_names = [:rb, :Q10],  # Parameters to monitor during training
-    yscale = identity,            # Scaling for outputs
-    patience = 30                 # Early stopping patience
+    yscale = identity             # Scaling for outputs
 )

 # =============================================================================
 # Results Analysis
 # =============================================================================
 # Check the training differences for Q10 parameter
 # This shows how close the model learned the true Q10 value
-out.train_diffs.Q10
+out.train_diffs.Q10
+
+using Hyperopt
+using Distributed
+using WGLMakie
+
+mspempty = ModelSpec()
+
+nhyper = 4
+ho = @thyperopt for i = nhyper,
+        opt = [AdamW(0.01), AdamW(0.1), RMSProp(0.001), RMSProp(0.01)],
+        input_batchnorm = [true, false]
+    hyper_parameters = (; opt, input_batchnorm)
+    println("Hyperparameter run: \n", i, " of ", nhyper, "\t with hyperparameters \t", hyper_parameters, "\t")
+    out = EasyHybrid.tune(hybrid_model, ds, mspempty; hyper_parameters..., nepochs = 10, plotting = false, show_progress = false, file_name = "test$i.jld2")
+    # return a rich record for this trial (stored in ho.results[i]);
+    # the first element is named so that `getfield.(ho.results, :best_loss)` below works
+    (best_loss = out.best_loss,
+        hyperps = hyper_parameters,
+        ps_st = (ps = out.ps, st = out.st),
+        file = "test$i.jld2",
+        i = i)
+end
+
+losses = getfield.(ho.results, :best_loss)
+hyperps = getfield.(ho.results, :hyperps)
+
+# Helper function to make optimizer names short and readable
+function short_opt_name(opt)
+    if opt isa AdamW
+        return "AdamW(η=$(opt.eta))"
+    elseif opt isa RMSProp
+        return "RMSProp(η=$(opt.eta))"
+    else
+        return string(typeof(opt))
+    end
+end
+
+# Sort losses and associated data by increasing loss
+idx = sortperm(losses)
+sorted_losses = losses[idx]
+sorted_hyperps = hyperps[idx]
+
+fig = Figure(figure_padding = 50)
+# Prepare tick labels with hyperparameter info for each trial (sorted)
+sorted_ticklabels = [
+    join([
+        k == :opt ? "opt=$(short_opt_name(v))" : "$k=$(repr(v))"
+        for (k, v) in pairs(hp)
+    ], "\n")
+    for hp in sorted_hyperps
+]
+ax = Makie.Axis(
+    fig[1, 1];
+    xlabel = "Trial",
+    ylabel = "Loss",
+    title = "Hyperparameter Tuning Results",
+    xgridvisible = false,
+    ygridvisible = false,
+    xticks = (1:length(sorted_losses), sorted_ticklabels),
+    xticklabelrotation = 45
+)
+scatter!(ax, 1:length(sorted_losses), sorted_losses; markersize = 15, color = :dodgerblue)
+
+best_idx = argmin(losses)
+best_trial = ho.results[best_idx]
+
+best_params = best_trial.ps_st  # (ps, st) of the best trial
+
+# Print the best hyperparameters
+printmin(ho)
+
+# Plot the results
+import Plots
+using Unitful
+Plots.plot(ho, xrotation = 25, left_margin = [100mm 0mm], bottom_margin = 60mm, ylab = "loss", size = (900, 900))
+
+# Train the model with the best hyperparameters
+best_hyperp = best_hyperparams(ho)
+out = EasyHybrid.tune(hybrid_model, ds, mspempty; best_hyperp..., nepochs = 100)

src/EasyHybrid.jl

Lines changed: 2 additions & 0 deletions

@@ -25,6 +25,7 @@ using JLD2
 using StyledStrings
 using Printf
 using Reexport: @reexport
+using Hyperopt

 @reexport begin
     import LuxCore
@@ -53,5 +54,6 @@ include("utils/show_train.jl")
 include("utils/helpers_for_HybridModel.jl")
 include("plotrecipes.jl")
 include("utils/helpers_data_loading.jl")
+include("tune.jl")

 end
