Skip to content

Conversation

PaulinaMartin96
Copy link
Contributor

No description provided.

@delete-merged-branch delete-merged-branch bot deleted the branch TuringLang:main December 24, 2021 10:30
@shravanngoswamii
Copy link
Member

Commenting Paulina's Energy plot implementation for Backup and Reference:

src/plot.jl

@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@userplot EnergyPlot
#@shorthands energyplot

struct _TracePlot; c; val; end
struct _MeanPlot; c; val;  end
struct _DensityPlot; c; val;  end
struct _HistogramPlot; c; val;  end
struct _AutocorPlot; lags; val;  end
#struct _EnergyPlot; marginal_energy; energy_transition; p_type; n_chains; end

# define alias functions for old syntax
const translationdict = Dict(
                        :traceplot => _TracePlot,
                        :meanplot => _MeanPlot,
                        :density => _DensityPlot,
                        :histogram => _HistogramPlot,
                        :autocorplot => _AutocorPlot,
                        :pooleddensity => _DensityPlot,
                      )

const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner, :energyplot)

@recipe f(c::Chains, s::Symbol) = c, [s]

@recipe function f(
    chains::Chains, i::Int;
    colordim = :chain,
    barbounds = (-Inf, Inf),
    maxlag = nothing,
    append_chains = false,
    plot_type = :density
)
    st = get(plotattributes, :seriestype, :traceplot)
    c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains

    if colordim == :parameter
        title --> "Chain $(MCMCChains.chains(c)[i])"
        label --> string.(names(c))
        val = c.value[:, :, i]
    elseif colordim == :chain
        title --> string(names(c)[i])
        label --> map(x -> "Chain $x", MCMCChains.chains(c))
        val = c.value[:, i, :]
    else
        throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
    end

    if st == :mixeddensity || st == :pooleddensity
        discrete = indiscretesupport(c, barbounds)
        st = if colordim == :chain
            discrete[i] ? :histogram : :density
        else
            # NOTE: It might make sense to overlay histograms and density plots here.
            :density
        end
        seriestype := st
    end

    if st == :autocorplot
        lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
        ac = autocor(c; sections = nothing, lags = lags)
        ac_mat = convert(Array, ac)
        val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
        _AutocorPlot(lags, val)
    #elseif st == :energyplot
    #    p_type = plot_type
    #    energy_section = get(c, :hamiltonian_energy)
    #    #@show energy_section
    #    #@show params.hamiltonian_energy
    #    n_chains = (append_chains ? 1 : size(c, 3))
    #    energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data)
    #    mean_energy = vec(mean(energy_data, dims = 1))
    #    marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
    #    energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
    #    _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
    elseif st  supportedplots
        translationdict[st](c, val)
    else
        range(c), val
    end
end

@recipe function f(p::_DensityPlot)
    xaxis --> "Sample value"
    yaxis --> "Density"
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_HistogramPlot)
    xaxis --> "Sample value"
    yaxis --> "Frequency"
    fillalpha --> 0.7
    bins --> 25
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_MeanPlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Mean"
    range(p.c), cummean(p.val)
end

@recipe function f(p::_AutocorPlot)
    seriestype := :path
    xaxis --> "Lag"
    yaxis --> "Autocorrelation"
    p.lags, p.val
end

@recipe function f(p::_TracePlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Sample value"
    range(p.c), p.val
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{Symbol};
    colordim = :chain
)
    colordim != :chain &&
        error("Symbol names are interpreted as parameter names, only compatible with ",
              "`colordim = :chain`")

    ret = indexin(parameters, names(chains))
    any(y === nothing for y in ret) && error("Parameter not found")

    return chains, Int.(ret)
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{<:Integer} = Int[];
    sections = _default_sections(chains),
    width = 500,
    height = 250,
    colordim = :chain,
    append_chains = false
)
    _chains = isempty(parameters) ? Chains(chains, _clean_sections(chains, sections)) : chains
    c = append_chains ? pool_chain(_chains) : _chains
    ptypes = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
    ptypes = ptypes isa Symbol ? (ptypes,) : ptypes
    @assert all(ptype -> ptype  supportedplots, ptypes)
    ntypes = length(ptypes)
    nrows, nvars, nchains = size(c)
    isempty(parameters) && (parameters = colordim == :chain ? (1:nvars) : (1:nchains))
    N = length(parameters)

    if :corner  ptypes
        size --> (ntypes*width, N*height)
        legend --> false

        multiple_plots = N * ntypes > 1
        if multiple_plots
            layout := (N, ntypes)
        end

        i = 0
        for par in parameters
            for ptype in ptypes
                i += 1

                @series begin
                    if multiple_plots
                        subplot := i
                    end
                    colordim := colordim
                    seriestype := ptype
                    c, par
                end
            end
        end
    else
        ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
        Corner(c, names(c)[parameters])
    end
end

struct Corner
    c
    parameters
end

@recipe function f(corner::Corner)
    label --> permutedims(corner.parameters)
    compact --> true
    size --> (600, 600)
    ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
    RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

#function compute_energy(
#    chains::Chains,
#    combined = false,
#    plot_type = :density
#)
#    st = get(plotattributes, :seriestype, :traceplot)
#
#    if st == :energyplot
#        p_type = plot_type
#        params = get(chains, :hamiltonian_energy)
#        n_chains = (combined ? 1 : size(chains, 3))
#        energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
#        mean_energy = vec(mean(energy_data, dims = 1))
#        marginal_energy = energy_data[:,i] .- mean_energy[i]
#        energy_transition = energy_data[2:end,i] .- energy_data[1:end-1,i]
#        _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
#    else
#
#    end
#end

#@recipe function f(
#    chains::Chains;
#    plot_type = :density,
#    append_chains = false
#)
#
#    st = get(plotattributes, :seriestype, :traceplot)
#    if st == :energyplot
#        p_type = plot_type
#        energy_section = get(chains, :hamiltonian_energy)
#        #@show energy_section
#        #@show params.hamiltonian_energy
#        n_chains = (append_chains ? 1 : size(chains, 3))
#        energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data)
#        mean_energy = vec(mean(energy_data, dims = 1))
#        marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
#        energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
#        _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
#    elseif st ∈ supportedplots
#        translationdict[st](c, val)
#    end
#end

function compute_energy(
    chains::Chains,
    combined = false,
    plot_type = :density
)
        p_type = plot_type
        params = get(chains, :hamiltonian_energy)
        isempty(params) && error("EnergyPlot receives a Chains object containing only the
            :internals section. Please use Chains(chain, [:internals]) to create it")
        n_chains = (combined ? 1 : size(chains, 3))
        energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
        mean_energy = vec(mean(energy_data, dims = 1))
        marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
        energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
        return marginal_energy, energy_transition, p_type, n_chains
    end

@recipe function f(
    p::EnergyPlot;
    combined = false,
    plot_type = :density
    )

    c = p.args[1]
    #p_type = plot_type
    #params = get(c, :hamiltonian_energy)
    #isempty(params) && error("EnergyPlot receives a Chains object containing only the
    #    :internals section. Please use Chains(chain, [:internals]) to create it")
    #n_chains = (combined ? 1 : size(c, 3))
    #energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
    #mean_energy = vec(mean(energy_data, dims = 1))
    #marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
    #energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
    marginal_energy, energy_transition, p_type, n_chains = compute_energy(c, combined, plot_type)
    k = 0
    for i in 1:n_chains
        k += 1
        title --> "Chain $(MCMCChains.chains(c)[i])"
        subplot := i
        @series begin
            seriestype := p_type
            label --> "Marginal energy"
            marginal_energy[i]
        end

        @series begin
            seriestype := p_type
            label --> "Energy transition"
            energy_transition[i]
        end
    end
end

#@recipe function f(p::_EnergyPlot)
#
#    k = 0
#    for i in 1:p.n_chains
#        k = 1
#        @series begin
#            subplot := i
#            seriestype := p.p_type
#            label --> "Marginal energy"
#            p.marginal_energy[i]
#        end
#
#        @series begin
#            subplot := i
#            seriestype := p.p_type
#            label --> "Energy transition"
#            p.energy_transition[i]
#        end
#    end
#end

@shravanngoswamii shravanngoswamii marked this pull request as ready for review August 17, 2025 21:22
Copy link

codecov bot commented Aug 17, 2025

Codecov Report

❌ Patch coverage is 91.89189% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 86.01%. Comparing base (a2f0499) to head (943917d).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
src/plot.jl 91.89% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #329      +/-   ##
==========================================
+ Coverage   85.81%   86.01%   +0.19%     
==========================================
  Files          20       20              
  Lines        1107     1144      +37     
==========================================
+ Hits          950      984      +34     
- Misses        157      160       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@shravanngoswamii
Copy link
Member

Model

using Turing
using StatsPlots

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end

chain = sample(gdemo(1.5, 2), NUTS(), 1000, progress=false)

energyplot(chain)
image
energyplot(chain, kind=:histogram)
image

@shravanngoswamii
Copy link
Member

@yebai This PR is ready for review!

@shravanngoswamii shravanngoswamii requested a review from yebai August 17, 2025 21:46
@yebai yebai requested a review from penelopeysm August 17, 2025 21:50
Copy link
Member

@penelopeysm penelopeysm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some typos.

shravanngoswamii and others added 2 commits August 18, 2025 09:33
Co-authored-by: Penelope Yong <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>
@shravanngoswamii
Copy link
Member

@penelopeysm Is the version bump fine? (from 7.1.0 to 7.2.0).

Copy link
Member

@penelopeysm penelopeysm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think the version bump is fine.

@shravanngoswamii
Copy link
Member

Before merging this, I just want to ask about BFMI -- Energy plot in ArViZ have this (https://python.arviz.org/en/latest/api/generated/arviz.bfmi.html#arviz.bfmi), we don't have it, is it nice addition? I mean I can always open a new PR for that but just asking before creating a issue...

@penelopeysm
Copy link
Member

I don't see a way to calculate BFMI with MCMCChains. Is there one?

If the function to calculate the statistic doesn't exist yet, I think it would make more sense to write that first, and then whether it's displayed in the plot can follow on later.

@shravanngoswamii
Copy link
Member

I don't see a way to calculate BFMI with MCMCChains. Is there one?

Nope, that's why I didn't bothered to add that in this PR... So Should I merge it?

@penelopeysm
Copy link
Member

Yup, I think separate things can be handled in separate PRs.

@shravanngoswamii shravanngoswamii merged commit dfa0096 into TuringLang:main Aug 18, 2025
10 checks passed
shravanngoswamii referenced this pull request Aug 18, 2025
…, (keep existing compat) (#483)

Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants