Skip to content

Commit dfa0096

Browse files
PaulinaMartin96shravanngoswamiipenelopeysm
authored
Implement Energy Plots (#329)
* Implement Energy Plot * Add tests for Energy plot * update docs to include Energy plot example * Minor version bump from 7.1.0 to 7.2.0 * run JuliaFormatter * run Julia Formatter * Update src/plot.jl Co-authored-by: Penelope Yong <[email protected]> * Update src/plot.jl Co-authored-by: Penelope Yong <[email protected]> --------- Co-authored-by: Shravan Goswami <[email protected]> Co-authored-by: Shravan Goswami <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent a2f0499 commit dfa0096

File tree

4 files changed

+123
-1
lines changed

4 files changed

+123
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "7.1.0"
6+
version = "7.2.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/statsplots.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,36 @@ plot(chn, seriestype = :violin)
133133
corner(chn)
134134
```
135135

136+
## Energy Plot
137+
138+
The energy plot is a diagnostic tool for HMC-based samplers (like NUTS) that helps diagnose sampling efficiency by visualizing the energy and energy transition distributions. This plot requires that the chain contains the internal sampler statistics `:hamiltonian_energy` and `:hamiltonian_energy_error`.
139+
140+
```@example statsplots
141+
# First, we generate a chain that includes the required sampler parameters.
142+
n_iter = 1000
143+
n_chain = 4
144+
val_params = randn(n_iter, 2, n_chain)
145+
val_energy = randn(n_iter, 1, n_chain) .+ 20
146+
val_energy_error = randn(n_iter, 1, n_chain) .* 0.5
147+
full_val = hcat(val_params, val_energy, val_energy_error)
148+
149+
parameter_names = [:a, :b, :hamiltonian_energy, :hamiltonian_energy_error]
150+
section_map = (
151+
parameters=[:a, :b],
152+
internals=[:hamiltonian_energy, :hamiltonian_energy_error],
153+
)
154+
155+
chn_energy = Chains(full_val, parameter_names, section_map)
156+
157+
# Generate the energy plot (default is a density plot).
158+
energyplot(chn_energy)
159+
```
160+
161+
```@example statsplots
162+
# The plot can also be generated as a histogram.
163+
energyplot(chn_energy, kind=:histogram)
164+
```
165+
136166
For plotting multiple parameters, ridgeline, forest and caterpillar plots can be useful.
137167

138168
## Ridgeline
@@ -156,6 +186,8 @@ forestplot(chn, chn.name_map[:parameters], hpd_val = [0.05, 0.15, 0.25], ordered
156186
## API
157187

158188
```@docs
189+
energyplot
190+
energyplot!
159191
ridgelineplot
160192
ridgelineplot!
161193
forestplot

src/plot.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@
66
@shorthands corner
77
@shorthands violinplot
88

9+
"""
10+
energyplot(chains::Chains; kind=:density, kwargs...)
11+
12+
Generate an energy plot for the samples in `chains`.
13+
14+
The energy plot is a diagnostic tool for HMC-based samplers like NUTS. It displays the distributions of the Hamiltonian energy and the energy transition (error) to diagnose sampler efficiency and identify divergences.
15+
16+
This plot is only available for chains that contain the `:hamiltonian_energy` and `:hamiltonian_energy_error` statistics in their `:internals` section.
17+
18+
# Keywords
19+
- `kind::Symbol` (default: `:density`): The type of plot to generate. Can be `:density` or `:histogram`.
20+
"""
21+
@userplot EnergyPlot
22+
923
"""
1024
ridgelineplot(chains::Chains[, params::Vector{Symbol}]; kwargs...)
1125
@@ -252,6 +266,58 @@ end
252266
end
253267
end
254268

269+
@recipe function f(p::EnergyPlot; kind = :density)
270+
chains = p.args[1]
271+
272+
if kind (:density, :histogram)
273+
error("`kind` must be one of `:density` or `:histogram`")
274+
end
275+
276+
internal_names = names(chains, :internals)
277+
required_params = [:hamiltonian_energy, :hamiltonian_energy_error]
278+
for param in required_params
279+
if param internal_names
280+
error(
281+
"`$param` not found in chain's internal parameters. Energy plots are only available for HMC/NUTS samplers.",
282+
)
283+
end
284+
end
285+
286+
pooled = pool_chain(chains)
287+
energy = vec(pooled[:, :hamiltonian_energy, :])
288+
energy_error = vec(pooled[:, :hamiltonian_energy_error, :])
289+
290+
mean_energy = mean(energy)
291+
std_energy = std(energy)
292+
centered_energy = (energy .- mean_energy) ./ std_energy
293+
scaled_energy_error = energy_error ./ std_energy
294+
295+
title := "Energy Plot"
296+
xaxis := "Standardized Energy"
297+
yaxis := "Density"
298+
legend := :topright
299+
300+
@series begin
301+
seriestype := kind
302+
label := "Marginal Energy"
303+
fillrange --> 0
304+
fillalpha --> 0.5
305+
normalize --> true
306+
bins --> 50
307+
centered_energy
308+
end
309+
310+
@series begin
311+
seriestype := kind
312+
label := "Energy Transition"
313+
fillrange --> 0
314+
fillalpha --> 0.5
315+
normalize --> true
316+
bins --> 50
317+
scaled_energy_error
318+
end
319+
end
320+
255321
@recipe function f(chains::Chains, parameters::AbstractVector{Symbol}; colordim = :chain)
256322
colordim != :chain && error(
257323
"Symbol names are interpreted as parameter names, only compatible with ",

test/plot_test.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ n_chain = 3
1313
val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
1414
val = hcat(val, rand(1:2, n_iter, 1, n_chain))
1515

16+
# This chain is missing the required energy parameters for the energyplot.
1617
chn = Chains(val)
1718

1819
# Silence all warnings.
@@ -92,6 +93,29 @@ Logging.disable_logging(Logging.Warn)
9293
display(plot(chn, 2))
9394
display(plot(chn, 2, colordim = :parameter))
9495
println()
96+
97+
@testset "Energy plot" begin
98+
# Construct a chain with the required internal parameters.
99+
val_params = randn(n_iter, 2, n_chain)
100+
val_energy = rand(n_iter, 1, n_chain) .* 10 .+ 20
101+
val_energy_error = randn(n_iter, 1, n_chain) .* 0.1
102+
full_val = hcat(val_params, val_energy, val_energy_error)
103+
104+
parameter_names = [:a, :b, :hamiltonian_energy, :hamiltonian_energy_error]
105+
section_map = (
106+
parameters = [:a, :b],
107+
internals = [:hamiltonian_energy, :hamiltonian_energy_error],
108+
)
109+
110+
chn_energy = Chains(full_val, parameter_names, section_map)
111+
112+
println("energyplot")
113+
display(energyplot(chn_energy))
114+
display(energyplot(chn_energy, kind = :histogram))
115+
println()
116+
117+
@test_throws ErrorException energyplot(chn)
118+
end
95119
end
96120

97121
# Reset log level.

0 commit comments

Comments
 (0)