Skip to content

Conversation

PaulinaMartin96
Copy link
Contributor

@PaulinaMartin96 PaulinaMartin96 commented Jul 19, 2021

"ppcplot" function was added for plotting prior/posterior predictive checks for one or more dependent variables. As args this function receives yobs_data, the observed data for dependet variables (a vector or matrix), and ypred_data , the posterior/prior predictive results (Chains object). It plots the observed data, a sample of predictions and the predictions mean.

As kwargs, this function receives:

  • yvar_name (vector of Symbol) which contains the name of the dependent variables to be plotted,
  • plot_type which can take :density , :cumulative, and :histogram as values,
  • predictive_check for plot titles and can be :prior or :posterior (default value is :posterior)
  • n_samples which established the number o samples to be plotted (default value is 50, but when plotting it is redefined as the minimum between 50 and sample size in ypred_data).

For more than one dependet variable in a single model, yvar_name must be provided and the order in which names variables appear must be the same as in the observed data matrix. This was done in order to separate predictions for every dependent variable, because predict does not return predictions ordered by variable.

The following is a working example for a model with one dependent variable

using Turing, StatsBase, Statistics, MCMCChains, StatsPlots

@model function linear_reg(x, y, σ = 0.1) 
            β ~ Normal(1, 0.5) 
  
            for i  eachindex(y) 
                y[i] ~ Normal* x[i], σ) 
            end 
        end; 
  
σ = 0.1; f(x) = 2 * x + 0.1 * randn();   
Δ = 0.01; xs_train = 0:Δ:10; ys_train = f.(xs_train);   
xs_test = [10 + i*Δ for i in 1:100]; ys_test = f.(xs_test); 
m_train = linear_reg(xs_train, ys_train, σ);

#Prior predictive check
chain_lin_reg = sample(m_train, Prior(), 200);   
m_test_prior = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);   
predictions_prior = predict(m_test_prior, chain_lin_reg) 
ppcplot(ys_test, predictions_prior, yvar_name = [:y_var], predictive_check = :prior, plot_type = :density )

image

And for posterior predictive check

#Posterior predictive check  
chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);   
m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);   
predictions_posterior = predict(m_test, chain_lin_reg) 
ppcplot(ys_test, predictions_posterior)

Plot_type = :density
image

Plot_type = :cumulative

ppcplot(ys_test, predictions_posterior, n_samples = 20, predictive_check = :posterior, plot_type = :cumulative, size = (900, 600))

image

Plot_type = :histogram
image

Aditionally, this is a working example for a model with two dependent variables

@model function linear_reg(x, y, z, σ = 0.1) 
            β ~ Normal(0, 1)
            γ ~ Normal(0, 1)
  
            for i  eachindex(y) 
                y[i] ~ Normal* x[i], σ)
                z[i] ~ Normal* x[i], σ)    
            end 
        end; 
  
σ = 0.1; f(x) = 2 * x + 0.1 * randn(); g(x) = 4 * x + 0.4 * randn();  
Δ = 0.01; xs_train = 0:Δ:10; ys_train = f.(xs_train); zs_train = g.(xs_train); 
xs_test = [10 + i*Δ for i in 1:100]; ys_test = f.(xs_test); zs_test = g.(xs_test);  
m_train = linear_reg(xs_train, ys_train, zs_train, σ); 
  
chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200); 
  
m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), Vector{Union{Missing, Float64}}(undef, length(zs_test)), σ); 
  
predictions = predict(m_test, chain_lin_reg)

var_test = hcat(ys_test, zs_test)
ppcplot(var_test, predictions, n_samples = 100, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :density, size = (900, 400))

image

ppcplot(var_test, predictions, n_samples = 30, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :cumulative, size = (900, 400))

image

var_name = [:y, :z]
ppcplot(var_test, predictions, yvar_name = var_name, n_samples = 10, predictive_check = :posterior, plot_type = :histogram, size = (900, 600))

image

@PaulinaMartin96
Copy link
Contributor Author

For this PR, should the version be 4.16.0 (after #316 ) or 5.1.0 (after #310 )?

@PaulinaMartin96 PaulinaMartin96 marked this pull request as ready for review July 23, 2021 19:32
@cpfiffer
Copy link
Member

Probably 4.16.0 since #310 is a bigger thing and probably won't have too much effect here.

@delete-merged-branch delete-merged-branch bot deleted the branch TuringLang:main December 24, 2021 10:30
@shravanngoswamii
Copy link
Member

Backup of Paulina's implementation of PPC Plots:

src/plot.jl

@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@userplot RidgelinePlot
@userplot ForestPlot
@shorthands ppcplot

struct _TracePlot; c; val; end
struct _MeanPlot; c; val;  end
struct _DensityPlot; c; val;  end
struct _HistogramPlot; c; val;  end
struct _AutocorPlot; lags; val;  end
struct _PPCPlot; y_obs; y_pred; ymean_pred; end

# define alias functions for old syntax
const translationdict = Dict(
                        :traceplot => _TracePlot,
                        :meanplot => _MeanPlot,
                        :density => _DensityPlot,
                        :histogram => _HistogramPlot,
                        :autocorplot => _AutocorPlot,
                        :pooleddensity => _DensityPlot,
                        :ppcplot => _PPCPlot
                      )

const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner)

@recipe f(c::Chains, s::Symbol) = c, [s]

@recipe function f(
    chains::Chains, i::Int;
    colordim = :chain,
    barbounds = (-Inf, Inf),
    maxlag = nothing,
    append_chains = false
)
    st = get(plotattributes, :seriestype, :traceplot)
    c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains

    if colordim == :parameter
        title --> "Chain $(MCMCChains.chains(c)[i])"
        label --> string.(names(c))
        val = c.value[:, :, i]
    elseif colordim == :chain
        title --> string(names(c)[i])
        label --> map(x -> "Chain $x", MCMCChains.chains(c))
        val = c.value[:, i, :]
    else
        throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
    end

    if st == :mixeddensity || st == :pooleddensity
        discrete = indiscretesupport(c, barbounds)
        st = if colordim == :chain
            discrete[i] ? :histogram : :density
        else
            # NOTE: It might make sense to overlay histograms and density plots here.
            :density
        end
        seriestype := st
    end

    if st == :autocorplot
        lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
        ac = autocor(c; sections = nothing, lags = lags)
        ac_mat = convert(Array, ac)
        val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
        _AutocorPlot(lags, val)
    elseif st  supportedplots
        translationdict[st](c, val)
    else
        range(c), val
    end
end

@recipe function f(p::_DensityPlot)
    xaxis --> "Sample value"
    yaxis --> "Density"
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_HistogramPlot)
    xaxis --> "Sample value"
    yaxis --> "Frequency"
    fillalpha --> 0.7
    bins --> 25
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_MeanPlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Mean"
    range(p.c), cummean(p.val)
end

@recipe function f(p::_AutocorPlot)
    seriestype := :path
    xaxis --> "Lag"
    yaxis --> "Autocorrelation"
    p.lags, p.val
end

@recipe function f(p::_TracePlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Sample value"
    range(p.c), p.val
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{Symbol};
    colordim = :chain
)
    colordim != :chain &&
        error("Symbol names are interpreted as parameter names, only compatible with ",
              "`colordim = :chain`")

    ret = indexin(parameters, names(chains))
    any(y === nothing for y in ret) && error("Parameter not found")

    return chains, Int.(ret)
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{<:Integer} = Int[];
    sections = _default_sections(chains),
    width = 500,
    height = 250,
    colordim = :chain,
    append_chains = false
)
    _chains = isempty(parameters) ? Chains(chains, _clean_sections(chains, sections)) : chains
    c = append_chains ? pool_chain(_chains) : _chains
    ptypes = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
    ptypes = ptypes isa Symbol ? (ptypes,) : ptypes
    @assert all(ptype -> ptype  supportedplots, ptypes)
    ntypes = length(ptypes)
    nrows, nvars, nchains = size(c)
    isempty(parameters) && (parameters = colordim == :chain ? (1:nvars) : (1:nchains))
    N = length(parameters)

    if :corner  ptypes
        size --> (ntypes*width, N*height)
        legend --> false

        multiple_plots = N * ntypes > 1
        if multiple_plots
            layout := (N, ntypes)
        end

        i = 0
        for par in parameters
            for ptype in ptypes
                i += 1

                @series begin
                    if multiple_plots
                        subplot := i
                    end
                    colordim := colordim
                    seriestype := ptype
                    c, par
                end
            end
        end
    else
        ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
        Corner(c, names(c)[parameters])
    end
end

struct Corner
    c
    parameters
end

@recipe function f(corner::Corner)
    label --> permutedims(corner.parameters)
    compact --> true
    size --> (600, 600)
    ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
    RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

function _compute_plot_data(
    i::Integer,
    chains::Chains,
    par_names::AbstractVector{Symbol},
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.4,
    _riser = 0.2,
    barbounds = (-Inf, Inf),
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)

    chain_dic = Dict(zip(quantile(chains)[:,1], quantile(chains)[:,4]))
    sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic))))
    sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)]
    par = (ordered ? sorted_par : par_names)
    hpdi = sort(hpd_val)

    chain_sections = MCMCChains.group(chains, Symbol(par[i]))
    chain_vec = vec(chain_sections.value.data)
    lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.lower
        for j in 1:length(hpdi)]
    upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.upper
        for j in 1:length(hpdi)]
    h = _riser + spacer*(i-1)
    qs = quantile(chain_vec, q)
    k_density = kde(chain_vec)
    if fill_hpd
        x_int = filter(x -> lower_hpd[1][1] <= x <= upper_hpd[1][1], k_density.x)
        val = pdf(k_density, x_int) .+ h
    elseif fill_q
        x_int = filter(x -> qs[1] <= x <= qs[2], k_density.x)
        val = pdf(k_density, x_int) .+ h
    else
        x_int = k_density.x
        val = k_density.density .+ h
    end
    chain_med = median(chain_vec)
    chain_mean = mean(chain_vec)
    min = minimum(k_density.density .+ h)
    q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med])

    return par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
        chain_mean, min, q_int
end

@recipe function f(
    p::RidgelinePlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)

    chn = p.args[1]
    par_names = p.args[2]

    for i in 1:length(par_names)
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean,
            min, q_int = _compute_plot_data(i, chn, par_names, hpd_val, q, spacer, _riser,
            show_mean, show_median, show_qi, show_hpdi, fill_q, fill_hpd, ordered)

        yticks --> (length(par_names) > 1 ?
            (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (length(par_names) > 1 ? "Parameters" : "Density" )
        @series begin
            seriestype := :hline
            label := nothing
            linecolor := "#BBBBBB"
            linewidth --> 1.2
            [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            fillrange --> min
            fillalpha --> 0.8
            x_int, val
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor --> "#000000"
            k_density.x, k_density.density .+ h
        end
        @series begin
            seriestype := :path
            label --> (show_mean ? (i == 1 ? "Mean" : nothing) : nothing)
            linecolor --> "dark red"
            linewidth --> (show_mean ? 1.2 : 0)
            [chain_mean, chain_mean], [min, min + pdf(k_density, chain_mean)]
        end
        @series begin
            seriestype := :path
            label --> (show_median ? (i == 1 ? "Median" : nothing) : nothing)
            linecolor --> "#000000"
            linewidth --> (show_median ? 1.2 : 0)
            [chain_med, chain_med], [min, min + pdf(k_density, chain_med)]
        end
        @series begin
            seriestype := :scatter
            label := (show_qi ? (i == 1 ? "Q$(q[1]), Q$(q[2])" : nothing) : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0)
            [qs[1], qs[2]], [h, h]
        end
        @series begin
            seriestype := :path
            label := (show_hpdi ? (i == 1 ? "$(Integer((1-hpdi[1])*100))% HPDI" : nothing)
                : nothing)
            linewidth --> (show_hpdi ? 2 : 0)
            seriesalpha --> 0.80
            linecolor --> :darkblue
            [lower_hpd[1][1], upper_hpd[1][1]], [h, h]
        end
    end
end

@recipe function f(
    p::ForestPlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)

    chn = p.args[1]
    par_names = p.args[2]

    for i in 1:length(par_names)
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean,
        min, q_int = _compute_plot_data(i, chn, par_names, hpd_val, q, spacer, _riser,
        show_mean, show_median, show_qi, show_hpdi, fill_q, fill_hpd, ordered)

        yticks --> (length(par_names) > 1 ?
            (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (length(par_names) > 1 ? "Parameters" : "Density" )

        for j in 1:length(hpdi)
            @series begin
                seriestype := :path
                label := (show_hpdi ?
                    (i == 1 ? "$(Integer((1-hpdi[j])*100))% HPDI" : nothing) : nothing)
                linecolor --> j
                linewidth --> (show_hpdi ? 1.5*j : 0)
                seriesalpha --> 0.80
                [lower_hpd[j][1], upper_hpd[j][1]], [h, h]
            end
        end
        @series begin
            seriestype := :scatter
            label := (show_median ? (i == 1 ? "Median" : nothing) : nothing)
            markershape --> :diamond
            markercolor --> "#000000"
            markersize --> (show_median ? length(hpdi) : 0)
            [chain_med], [h]
        end
        @series begin
            seriestype := :scatter
            label := (show_mean ? (i == 1 ? "Mean" : nothing) : nothing)
            markershape --> :circle
            markercolor --> :gray
            markersize --> (show_mean ? length(hpdi) : 0)
            [chain_mean], [h]
        end
        @series begin
            seriestype := :scatter
            label := (show_qi ? (i == 1 ? "Q1 = $(q[1]), Q3 = $(q[2])" : nothing) : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0.0)
            [qs[1], qs[2]], [h, h]
        end
    end
end

@recipe function f(
    yobs_data,
    ypred_data::Chains;
    yvar_name::AbstractVector{Symbol} = [],
    plot_type = :density,
    predictive_check = :posterior,
    n_samples::Int = 50
    )

    st = get(plotattributes, :seriestype, :traceplot)

    if st == :ppcplot
        N = n_samples <= size(ypred_data)[1] ? n_samples : size(ypred_data)[1]
        index = sample(1:size(ypred_data)[1], N, replace = false, ordered = true)

        if ndims(yobs_data) == 1
            y_obs = plot_type == :cumulative ? ecdf(vec(yobs_data)) : vec(yobs_data)
            predictions = ypred_data.value.data[index,:,:]
            ymean_pred = (plot_type == :cumulative
                ? ecdf(vec(mean(ypred_data.value.data, dims = 1)))
                : vec(mean(ypred_data.value.data, dims = 1)))

            if plot_type == :density || plot_type == :cumulative
                if predictive_check == :posterior
                    title --> "Posterior predictive check"
                elseif predictive_check == :prior
                    title --> "Prior predictive check"
                else
                    throw(ArgumentError("`predictive_check` must be one of `prior` or `posterior`"))
                end
                for i in 1:N
                    y_pred = (plot_type == :cumulative ? ecdf(vec(predictions[i,:,:]))
                        : vec(predictions[i,:,:]))
                    ypred_label = (isempty(yvar_name) ? (i == 1 ? "y pred" : nothing)
                        : (i == 1 ? "$(yvar_name[1]) pred" : nothing))
                    @series begin
                        seriestype := :density
                        seriesalpha --> 0.3
                        linecolor --> "#BBBBBB"
                        label --> ypred_label
                        y_pred
                    end
                end
                @series begin
                    seriestype := :density
                    label --> (isempty(yvar_name) ? "y obs" : "$(yvar_name[1]) obs")
                    y_obs
                end
                @series begin
                    seriestype := :density
                    label --> (isempty(yvar_name) ? "y mean" : "$(yvar_name[1]) mean")
                    ymean_pred
                end

            elseif plot_type == :histogram
                layout --> N + 2
                k = 1
                @series begin
                    subplot := k
                    seriestype := :histogram
                    label --> (isempty(yvar_name) ? "y obs" : "$(yvar_name[1]) obs")
                    y_obs
                end
                k = 2
                @series begin
                    subplot := k
                    seriestype := :histogram
                    label --> (isempty(yvar_name) ? "y mean" : "$(yvar_name[1]) mean")
                    ymean_pred
                end
                for i in 1:N
                    y_pred = predictions[i,:,:]
                    @series begin
                        subplot := k + i
                        seriestype := :histogram
                        label --> nothing
                        y_pred
                    end
                end
            else
                throw(ArgumentError("`plot_type` must be one of `:density`, `:cumulative` or `histogram`"))
            end

        elseif ndims(yobs_data) > 1
            n_yval = size(yobs_data)[1]
            n_yvar = size(yobs_data)[2]
            mean_arr = reshape(mean(ypred_data.value.data, dims = 1), (n_yval, n_yvar))
            k = 0
            for j in 1:n_yvar
                sections = MCMCChains.group(ypred_data, Symbol(yvar_name[j]))
                predictions = sections.value.data[index,:,:]
                y_obs = (plot_type == :cumulative ? ecdf(vec(yobs_data[:,j]))
                    : vec(yobs_data[:,j]))
                ymean_pred = (plot_type == :cumulative ? ecdf(vec(mean_arr[:,j]))
                    : vec(mean_arr[:,j]))

                if plot_type == :density || plot_type == :cumulative
                    k += 1
                    layout --> (1, n_yvar)
                    if predictive_check == :posterior
                        title --> "Posterior predictive check"
                    elseif predictive_check == :prior
                        title --> "Prior predictive check"
                    else
                        throw(ArgumentError("`predictive_check` must be one of `prior` or `posterior`"))
                    end

                    for i in 1:N
                        y_pred = (plot_type == :cumulative ? ecdf(vec(predictions[i,:,:]))
                            : vec(predictions[i,:,:]))
                        @series begin
                            subplot := k
                            seriestype := :density
                            seriesalpha --> 0.3
                            linecolor --> "#BBBBBB"
                            label --> (i == 1 ? "$(yvar_name[j]) pred" : nothing)
                            y_pred
                        end
                    end
                    @series begin
                        subplot := k
                        seriestype := :density
                        label --> "$(yvar_name[j]) obs"
                        y_obs
                    end
                    @series begin
                        subplot := k
                        seriestype := :density
                        label --> "$(yvar_name[j]) mean"
                        ymean_pred
                    end

                elseif plot_type == :histogram
                    subplot := k
                    layout --> N + 2
                    h = 1
                    @series begin
                        subplot := h
                        seriestype := :histogram
                        label --> "$(yvar_name[j]) obs"
                        y_obs
                    end
                    h = 2
                    @series begin
                        subplot := h
                        seriestype := :histogram
                        label --> "$(yvar_name[j]) mean"
                        ymean_pred
                    end
                    for i in 1:N
                        y_pred = predictions[i,:,:]
                        @series begin
                            subplot := h + i
                            seriestype := :histogram
                            label --> nothing
                            y_pred
                        end
                    end

                else
                    throw(ArgumentError("`plot_type` must be one of `:density`, `:cumulative` or `:histogram`"))
                end
            end
        else
            throw(ArgumentError("Observed data must have `dim > 1`"))
        end
    else

    end
end

@recipe function f(p::_PPCPlot)
    p.y_obs, p.y_pred
end                                                   

Copy link

codecov bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 95.30516% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.57%. Comparing base (6909f74) to head (8a9d8cb).

Files with missing lines Patch % Lines
src/plot.jl 95.30% 10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #319      +/-   ##
==========================================
+ Coverage   86.13%   87.57%   +1.43%     
==========================================
  Files          20       20              
  Lines        1147     1360     +213     
==========================================
+ Hits          988     1191     +203     
- Misses        159      169      +10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants