-
Notifications
You must be signed in to change notification settings - Fork 29
Prior/posterior predictive check plots #319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
PaulinaMartin96
wants to merge
3
commits into
TuringLang:main
Choose a base branch
from
PaulinaMartin96:pm/ppc_plot
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
+648
−3
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Probably 4.16.0 since #310 is a bigger thing and probably won't have too much effect here. |
Backup of Paulina's implementation of PPC Plots: src/plot.jl
# Plot-command declarations for the series types handled by the recipes in this file.
# NOTE(review): presumably `@shorthands name` generates the `name`/`name!` plotting
# functions — confirm against the RecipesBase version in use.
@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
# User-plot wrapper types; their recipes below read the positional arguments
# through `p.args`.
@userplot RidgelinePlot
@userplot ForestPlot
@shorthands ppcplot
# Untyped wrapper types used to dispatch each series type to its own recipe.

# Trace plot: `c` is the chain, `val` the values to draw.
struct _TracePlot
    c
    val
end

# Running-mean plot: `c` is the chain, `val` the values to accumulate.
struct _MeanPlot
    c
    val
end

# Density plot: `c` is the chain, `val` the values whose columns become series.
struct _DensityPlot
    c
    val
end

# Histogram plot: `c` is the chain, `val` the values whose columns become series.
struct _HistogramPlot
    c
    val
end

# Autocorrelation plot: `lags` are the x values, `val` the autocorrelations.
struct _AutocorPlot
    lags
    val
end

# Predictive-check plot: observed values, predicted values and mean prediction.
struct _PPCPlot
    y_obs
    y_pred
    ymean_pred
end
# Alias table for the old syntax: maps a series-type symbol to the wrapper type whose
# recipe draws it. `:pooleddensity` shares the density wrapper with `:density`.
# NOTE(review): `_PPCPlot` has three fields, whereas the per-panel recipe builds
# wrappers as `translationdict[st](c, val)` — confirm `:ppcplot` never takes that path.
const translationdict = Dict{Symbol,DataType}(
    :traceplot => _TracePlot,
    :meanplot => _MeanPlot,
    :density => _DensityPlot,
    :histogram => _HistogramPlot,
    :autocorplot => _AutocorPlot,
    :pooleddensity => _DensityPlot,
    :ppcplot => _PPCPlot,
)

# Every series type accepted by the `Chains` recipes: the keys of the alias table
# plus the two types that have no wrapper of their own.
const supportedplots = vcat(collect(keys(translationdict)), [:mixeddensity, :corner])
# A single parameter name is wrapped in a one-element vector and handed on together
# with the chain.
@recipe function f(c::Chains, s::Symbol)
    c, [s]
end
# Recipe for one panel of a chain plot.
#
# `i` selects a parameter when `colordim == :chain` (one series per chain) and a chain
# when `colordim == :parameter` (one series per parameter). The series type is read
# from `plotattributes` and defaults to `:traceplot`.
#
# Keywords:
# - `colordim`: `:chain` or `:parameter`; anything else throws an `ArgumentError`.
# - `barbounds`: forwarded to `indiscretesupport` for the mixed/pooled density types.
# - `maxlag`: largest lag for `:autocorplot`; derived from the chain length if `nothing`.
# - `append_chains`: pool the chains first (always done for `:pooleddensity`).
@recipe function f(
    chains::Chains, i::Int;
    colordim = :chain,
    barbounds = (-Inf, Inf),
    maxlag = nothing,
    append_chains = false
)
    plot_kind = get(plotattributes, :seriestype, :traceplot)
    pooled = append_chains || plot_kind == :pooleddensity
    chn = pooled ? pool_chain(chains) : chains

    if colordim == :parameter
        title --> "Chain $(MCMCChains.chains(chn)[i])"
        label --> string.(names(chn))
        samples = chn.value[:, :, i]
    elseif colordim == :chain
        title --> string(names(chn)[i])
        label --> map(x -> "Chain $x", MCMCChains.chains(chn))
        samples = chn.value[:, i, :]
    else
        throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
    end

    # Mixed and pooled densities resolve to a concrete series type: a histogram for a
    # parameter flagged by `indiscretesupport`, a density otherwise.
    if plot_kind in (:mixeddensity, :pooleddensity)
        is_discrete = indiscretesupport(chn, barbounds)
        # NOTE: It might make sense to overlay histograms and density plots here.
        plot_kind = (colordim == :chain && is_discrete[i]) ? :histogram : :density
        seriestype := plot_kind
    end

    if plot_kind == :autocorplot
        last_lag = if maxlag === nothing
            round(Int, 10 * log10(length(range(chn))))
        else
            maxlag
        end
        lag_range = 0:last_lag
        acf = autocor(chn; sections = nothing, lags = lag_range)
        acf_array = convert(Array, acf)
        samples = colordim == :parameter ? acf_array[:, :, i]' : acf_array[i, :, :]
        _AutocorPlot(lag_range, samples)
    elseif plot_kind ∈ supportedplots
        # Wrap the data so the matching wrapper recipe takes over.
        translationdict[plot_kind](chn, samples)
    else
        range(chn), samples
    end
end
# Density panel: one series per column of `p.val`, with `missing` entries dropped.
@recipe function f(p::_DensityPlot)
    trim --> true
    yaxis --> "Density"
    xaxis --> "Sample value"
    ncols = size(p.val, 2)
    [collect(skipmissing(p.val[:, col])) for col in 1:ncols]
end
# Histogram panel: one series per column of `p.val`, with `missing` entries dropped.
@recipe function f(p::_HistogramPlot)
    bins --> 25
    fillalpha --> 0.7
    trim --> true
    yaxis --> "Frequency"
    xaxis --> "Sample value"
    ncols = size(p.val, 2)
    [collect(skipmissing(p.val[:, col])) for col in 1:ncols]
end
# Running-mean panel: `cummean` of the values against the chain's iteration range.
@recipe function f(p::_MeanPlot)
    seriestype := :path
    yaxis --> "Mean"
    xaxis --> "Iteration"
    iterations = range(p.c)
    running_mean = cummean(p.val)
    iterations, running_mean
end
# Autocorrelation panel: the stored autocorrelations drawn as a path over the lags.
@recipe function f(p::_AutocorPlot)
    seriestype := :path
    yaxis --> "Autocorrelation"
    xaxis --> "Lag"
    xs, ys = p.lags, p.val
    xs, ys
end
# Trace panel: the stored values drawn as a path over the chain's iteration range.
@recipe function f(p::_TracePlot)
    seriestype := :path
    yaxis --> "Sample value"
    xaxis --> "Iteration"
    iterations = range(p.c)
    iterations, p.val
end
# Translate parameter names into their positions in `names(chains)` and pass the chain
# on with those integer indices.
#
# Errors if `colordim` is not `:chain` or if any requested name is absent.
@recipe function f(
    chains::Chains,
    parameters::AbstractVector{Symbol};
    colordim = :chain
)
    if colordim != :chain
        error("Symbol names are interpreted as parameter names, only compatible with ",
              "`colordim = :chain`")
    end

    # `indexin` yields `nothing` for every name that is not present.
    positions = indexin(parameters, names(chains))
    if any(pos -> pos === nothing, positions)
        error("Parameter not found")
    end

    chains, Int.(positions)
end
# Top-level chain recipe: lays out one row per selected index and one column per
# requested series type, or builds a `Corner` when `:corner` is requested.
#
# Arguments:
# - `parameters`: integer indices to plot. When empty, the chain is first restricted
#   to `sections` and every parameter is used (every chain for
#   `colordim = :parameter`).
#
# Keywords:
# - `sections`: sections kept when `parameters` is empty.
# - `width`, `height`: size of a single panel; the figure is sized from the grid.
# - `colordim`: forwarded to each panel.
# - `append_chains`: pool the chains before plotting.
@recipe function f(
    chains::Chains,
    parameters::AbstractVector{<:Integer} = Int[];
    sections = _default_sections(chains),
    width = 500,
    height = 250,
    colordim = :chain,
    append_chains = false
)
    selected = if isempty(parameters)
        Chains(chains, _clean_sections(chains, sections))
    else
        chains
    end
    c = append_chains ? pool_chain(selected) : selected

    # The series type may be a single symbol or a collection of symbols.
    ptypes = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
    if ptypes isa Symbol
        ptypes = (ptypes,)
    end
    @assert all(ptype -> ptype ∈ supportedplots, ptypes)
    ntypes = length(ptypes)

    _, nvars, nchains = size(c)
    if isempty(parameters)
        parameters = colordim == :chain ? (1:nvars) : (1:nchains)
    end
    N = length(parameters)

    if :corner in ptypes
        ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
        Corner(c, names(c)[parameters])
    else
        size --> (ntypes*width, N*height)
        legend --> false
        multiple_plots = N * ntypes > 1
        if multiple_plots
            layout := (N, ntypes)
        end

        # Subplots are numbered row by row: one row per index, one column per type.
        for (row, par) in enumerate(parameters), (col, ptype) in enumerate(ptypes)
            @series begin
                if multiple_plots
                    subplot := (row - 1) * ntypes + col
                end
                colordim := colordim
                seriestype := ptype
                c, par
            end
        end
    end
end
# Wrapper for a corner plot: the chain `c` and the `parameters` to include.
struct Corner
    c
    parameters
end

# Corner-plot recipe: slices the selected parameters out of every chain, stacks the
# slices row-wise, and delegates to the `:cornerplot` recipe type.
@recipe function f(corner::Corner)
    size --> (600, 600)
    compact --> true
    label --> permutedims(corner.parameters)
    per_chain = [
        Array(corner.c.value[:, corner.parameters, k]) for k in chains(corner.c)
    ]
    RecipesBase.recipetype(:cornerplot, vcat(per_chain...))
end
# Compute everything the ridgeline and forest recipes need to draw parameter number `i`.
#
# All options are positional, so callers must pass them in exactly this order.
# `barbounds`, `show_mean`, `show_median` and `show_hpdi` are accepted but not used in
# this function.
#
# Returns, in order: the parameter names in drawing order, the sorted HPD levels, the
# lower and upper HPD bounds per level, the baseline height `h`, the quantiles `qs`,
# the kernel density estimate, the x values and heights of the filled region, the
# median, the mean, the lowest point of the shifted density, and the points for the
# quantile-interval marker.
function _compute_plot_data(
    i::Integer,
    chains::Chains,
    par_names::AbstractVector{Symbol},
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.4,
    _riser = 0.2,
    barbounds = (-Inf, Inf),
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    # Rank the entries of the quantile table by its fourth column (presumably the
    # median — TODO confirm against `quantile(chains)`).
    # NOTE(review): with `ordered = true` the names are the first `length(par_names)`
    # entries of this ranking over everything in `chains`, not a reordering of
    # `par_names` itself — confirm this is intended.
    qtable = quantile(chains)
    chain_dic = Dict(zip(qtable[:, 1], qtable[:, 4]))
    ranked = sort!([(v, k) for (k, v) in chain_dic])
    sorted_par = [last(entry) for entry in ranked[1:length(par_names)]]
    par = ordered ? sorted_par : par_names

    hpdi = sort(hpd_val)
    chain_sections = MCMCChains.group(chains, Symbol(par[i]))
    chain_vec = vec(chain_sections.value.data)

    intervals = [MCMCChains.hpd(chain_sections, alpha = level) for level in hpdi]
    lower_hpd = [interval.nt.lower for interval in intervals]
    upper_hpd = [interval.nt.upper for interval in intervals]

    # Baseline of this ridge: each parameter sits `spacer` above the previous one.
    h = _riser + spacer*(i-1)
    qs = quantile(chain_vec, q)
    k_density = kde(chain_vec)

    # Region to fill: the first HPD interval takes precedence over the quantile
    # interval; with neither enabled the whole density is used.
    if fill_hpd || fill_q
        lo, hi = fill_hpd ? (lower_hpd[1][1], upper_hpd[1][1]) : (qs[1], qs[2])
        x_int = filter(x -> lo <= x <= hi, k_density.x)
        val = pdf(k_density, x_int) .+ h
    else
        x_int = k_density.x
        val = k_density.density .+ h
    end

    chain_med = median(chain_vec)
    chain_mean = mean(chain_vec)
    floor_y = minimum(k_density.density .+ h)
    q_int = show_qi ? [qs[1], chain_med, qs[2]] : [chain_med]

    return par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
        chain_mean, floor_y, q_int
end
# Ridgeline plot: one kernel-density ridge per parameter, stacked vertically.
#
# Positional plot arguments (read from `p.args`): the chain and the parameter names.
#
# Keywords:
# - `hpd_val`: HPD levels; the smallest one is drawn as the HPD interval.
# - `q`: lower and upper quantile of the quantile interval.
# - `spacer`, `_riser`: vertical distance between ridges and height of the first one.
# - `show_mean`, `show_median`, `show_qi`, `show_hpdi`: toggle the matching markers.
# - `fill_q`, `fill_hpd`: fill the ridge over the quantile or the HPD interval.
# - `ordered`: forwarded to `_compute_plot_data` to choose the drawing order.
@recipe function f(
    p::RidgelinePlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    chn = p.args[1]
    par_names = p.args[2]

    for i in eachindex(par_names)
        # `_compute_plot_data` takes its options positionally and expects `barbounds`
        # between `_riser` and `show_mean`. Passing its default here keeps every later
        # flag in its own slot; without it `fill_q`, `fill_hpd` and `ordered` were
        # each read from the neighbouring argument.
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
            chain_mean, floor_y, q_int = _compute_plot_data(i, chn, par_names, hpd_val,
            q, spacer, _riser, (-Inf, Inf), show_mean, show_median, show_qi, show_hpdi,
            fill_q, fill_hpd, ordered)

        # Percentage for the legend. Rounding first avoids an `InexactError` when the
        # level does not map to a whole percentage (e.g. 0.025) or suffers from
        # floating-point noise.
        hpdi_pct = round((1 - hpdi[1]) * 100; digits = 2)
        hpdi_text = "$(isinteger(hpdi_pct) ? Int(hpdi_pct) : hpdi_pct)% HPDI"

        yticks --> (length(par_names) > 1 ?
            (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (length(par_names) > 1 ? "Parameters" : "Density")

        # Baseline of the ridge.
        @series begin
            seriestype := :hline
            label := nothing
            linecolor := "#BBBBBB"
            linewidth --> 1.2
            [h]
        end
        # Filled region of the density.
        @series begin
            seriestype := :path
            label := nothing
            fillrange --> floor_y
            fillalpha --> 0.8
            x_int, val
        end
        # Outline of the full density.
        @series begin
            seriestype := :path
            label := nothing
            linecolor --> "#000000"
            k_density.x, k_density.density .+ h
        end
        # Vertical line at the mean; only the first ridge contributes a legend entry.
        @series begin
            seriestype := :path
            label --> ((show_mean && i == 1) ? "Mean" : nothing)
            linecolor --> "dark red"
            linewidth --> (show_mean ? 1.2 : 0)
            [chain_mean, chain_mean], [floor_y, floor_y + pdf(k_density, chain_mean)]
        end
        # Vertical line at the median.
        @series begin
            seriestype := :path
            label --> ((show_median && i == 1) ? "Median" : nothing)
            linecolor --> "#000000"
            linewidth --> (show_median ? 1.2 : 0)
            [chain_med, chain_med], [floor_y, floor_y + pdf(k_density, chain_med)]
        end
        # Quantile-interval markers and connecting line.
        @series begin
            seriestype := :scatter
            label := ((show_qi && i == 1) ? "Q$(q[1]), Q$(q[2])" : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0)
            [qs[1], qs[2]], [h, h]
        end
        # HPD interval for the smallest level.
        @series begin
            seriestype := :path
            label := ((show_hpdi && i == 1) ? hpdi_text : nothing)
            linewidth --> (show_hpdi ? 2 : 0)
            seriesalpha --> 0.80
            linecolor --> :darkblue
            [lower_hpd[1][1], upper_hpd[1][1]], [h, h]
        end
    end
end
# Forest plot: for every parameter, nested HPD intervals on a horizontal line plus
# optional mean, median and quantile-interval markers.
#
# Positional plot arguments (read from `p.args`): the chain and the parameter names.
#
# Keywords:
# - `hpd_val`: HPD levels; one interval is drawn per level.
# - `q`: lower and upper quantile of the quantile interval.
# - `spacer`, `_riser`: vertical distance between rows and height of the first one.
# - `show_mean`, `show_median`, `show_qi`, `show_hpdi`: toggle the matching markers.
# - `fill_q`, `fill_hpd`, `ordered`: forwarded to `_compute_plot_data`.
@recipe function f(
    p::ForestPlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    chn = p.args[1]
    par_names = p.args[2]

    for i in eachindex(par_names)
        # `_compute_plot_data` takes its options positionally and expects `barbounds`
        # between `_riser` and `show_mean`. Passing its default here keeps every later
        # flag in its own slot; without it `show_qi` and `ordered` were each read from
        # the neighbouring argument.
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
            chain_mean, floor_y, q_int = _compute_plot_data(i, chn, par_names, hpd_val,
            q, spacer, _riser, (-Inf, Inf), show_mean, show_median, show_qi, show_hpdi,
            fill_q, fill_hpd, ordered)

        yticks --> (length(par_names) > 1 ?
            (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (length(par_names) > 1 ? "Parameters" : "Density")

        # One interval per HPD level; later levels are drawn thicker.
        for j in eachindex(hpdi)
            # Rounding first avoids an `InexactError` when the level does not map to a
            # whole percentage (e.g. 0.025) or suffers from floating-point noise.
            hpdi_pct = round((1 - hpdi[j]) * 100; digits = 2)
            hpdi_text = "$(isinteger(hpdi_pct) ? Int(hpdi_pct) : hpdi_pct)% HPDI"
            @series begin
                seriestype := :path
                label := ((show_hpdi && i == 1) ? hpdi_text : nothing)
                linecolor --> j
                linewidth --> (show_hpdi ? 1.5*j : 0)
                seriesalpha --> 0.80
                [lower_hpd[j][1], upper_hpd[j][1]], [h, h]
            end
        end
        # Median marker; only the first row contributes a legend entry.
        @series begin
            seriestype := :scatter
            label := ((show_median && i == 1) ? "Median" : nothing)
            markershape --> :diamond
            markercolor --> "#000000"
            markersize --> (show_median ? length(hpdi) : 0)
            [chain_med], [h]
        end
        # Mean marker.
        @series begin
            seriestype := :scatter
            label := ((show_mean && i == 1) ? "Mean" : nothing)
            markershape --> :circle
            markercolor --> :gray
            markersize --> (show_mean ? length(hpdi) : 0)
            [chain_mean], [h]
        end
        # Quantile-interval markers and connecting line.
        @series begin
            seriestype := :scatter
            label := ((show_qi && i == 1) ? "Q1 = $(q[1]), Q3 = $(q[2])" : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0.0)
            [qs[1], qs[2]], [h, h]
        end
    end
end
# Title for the density/cumulative predictive-check panels.
# Throws an `ArgumentError` for anything other than `:posterior` or `:prior`.
function _ppc_title(predictive_check)
    predictive_check == :posterior && return "Posterior predictive check"
    predictive_check == :prior && return "Prior predictive check"
    throw(ArgumentError("`predictive_check` must be one of `:prior` or `:posterior`"))
end

# Flatten `data` for plotting; cumulative plots get its empirical CDF instead.
function _ppc_series_data(data, plot_type)
    flat = vec(data)
    return plot_type == :cumulative ? ecdf(flat) : flat
end

# Prior/posterior predictive check: plots the observed data, a random subset of the
# predictive draws, and the mean prediction.
#
# The recipe only acts when the series type is `:ppcplot`; for any other series type
# it contributes no series.
#
# Arguments:
# - `yobs_data`: observed values, a vector (one dependent variable) or an array with
#   one column per dependent variable.
# - `ypred_data`: chain holding the predictive draws.
#
# Keywords (left untyped: type annotations on recipe keywords are discarded by
# RecipesBase with a warning):
# - `yvar_name`: names of the dependent variables, in the column order of
#   `yobs_data`. Optional for a vector, required for multi-column data, where each
#   name is used to `group` the matching predictions out of `ypred_data`.
# - `plot_type`: `:density`, `:cumulative` or `:histogram`.
# - `predictive_check`: `:posterior` or `:prior`; selects the title of the density and
#   cumulative plots.
# - `n_samples`: number of draws to show, capped at the number of draws available.
#
# Throws an `ArgumentError` for an unknown `plot_type` or `predictive_check`, for
# zero-dimensional observed data, or when `yvar_name` does not cover every column.
@recipe function f(
    yobs_data,
    ypred_data::Chains;
    yvar_name = Symbol[],
    plot_type = :density,
    predictive_check = :posterior,
    n_samples = 50
)
    st = get(plotattributes, :seriestype, :traceplot)
    if st == :ppcplot
        plot_type in (:density, :cumulative, :histogram) || throw(ArgumentError(
            "`plot_type` must be one of `:density`, `:cumulative` or `:histogram`"
        ))
        ndims(yobs_data) >= 1 || throw(ArgumentError(
            "Observed data must have at least one dimension"
        ))

        is_hist = plot_type == :histogram
        # Resolved before use in `-->` so an invalid `predictive_check` is rejected
        # even when the caller supplies a title of their own.
        ppc_title = is_hist ? nothing : _ppc_title(predictive_check)

        single = ndims(yobs_data) == 1
        n_yvar = single ? 1 : size(yobs_data)[2]
        single || length(yvar_name) >= n_yvar || throw(ArgumentError(
            "`yvar_name` must name each of the $n_yvar columns of the observed data"
        ))

        # Random subset of draws, kept in iteration order.
        n_draws = size(ypred_data)[1]
        N = min(n_samples, n_draws)
        index = sample(1:n_draws, N, replace = false, ordered = true)

        if is_hist
            # Every variable gets its own run of N + 2 panels: observed, mean, draws.
            layout --> n_yvar * (N + 2)
        else
            if !single
                layout --> (1, n_yvar)
            end
            title --> ppc_title
        end

        for j in 1:n_yvar
            if single
                var_data = ypred_data.value.data
                obs = yobs_data
                prefix = isempty(yvar_name) ? "y" : string(yvar_name[1])
            else
                # Select this variable's predictions by name rather than by position,
                # so the result does not depend on the parameter order in the chain.
                var_data = MCMCChains.group(ypred_data, Symbol(yvar_name[j])).value.data
                obs = yobs_data[:, j]
                prefix = string(yvar_name[j])
            end
            draws = var_data[index, :, :]
            y_obs = _ppc_series_data(obs, plot_type)
            # Mean over draws of this variable's predictions, pooled across chains.
            ymean_pred = _ppc_series_data(mean(var_data, dims = 1), plot_type)

            if is_hist
                offset = (j - 1) * (N + 2)
                @series begin
                    subplot := offset + 1
                    seriestype := :histogram
                    label --> "$prefix obs"
                    y_obs
                end
                @series begin
                    subplot := offset + 2
                    seriestype := :histogram
                    label --> "$prefix mean"
                    ymean_pred
                end
                for i in 1:N
                    @series begin
                        subplot := offset + 2 + i
                        seriestype := :histogram
                        label --> nothing
                        # Pooled across chains, like the density draws.
                        vec(draws[i, :, :])
                    end
                end
            else
                for i in 1:N
                    @series begin
                        if !single
                            subplot := j
                        end
                        seriestype := :density
                        seriesalpha --> 0.3
                        linecolor --> "#BBBBBB"
                        # Only the first draw contributes a legend entry.
                        label --> (i == 1 ? "$prefix pred" : nothing)
                        _ppc_series_data(draws[i, :, :], plot_type)
                    end
                end
                @series begin
                    if !single
                        subplot := j
                    end
                    seriestype := :density
                    label --> "$prefix obs"
                    y_obs
                end
                @series begin
                    if !single
                        subplot := j
                    end
                    seriestype := :density
                    label --> "$prefix mean"
                    ymean_pred
                end
            end
        end
    end
    # Everything is emitted through `@series`; the recipe itself returns no data.
    nothing
end
# Recipe for the `_PPCPlot` wrapper: hands on the observed and predicted values.
# The stored mean prediction (`ymean_pred`) is not used here.
@recipe function f(p::_PPCPlot)
    observed, predicted = p.y_obs, p.y_pred
    observed, predicted
end
420f9b2
to
537514a
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #319 +/- ##
==========================================
+ Coverage 86.13% 87.57% +1.43%
==========================================
Files 20 20
Lines 1147 1360 +213
==========================================
+ Hits 988 1191 +203
- Misses 159 169 +10 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
A `ppcplot` function was added for plotting prior/posterior predictive checks for one or more dependent variables. As positional arguments, this function receives `yobs_data`, the observed data for the dependent variables (a vector or matrix), and `ypred_data`, the posterior/prior predictive results (a `Chains` object). It plots the observed data, a sample of the predictions, and the mean of the predictions.
As keyword arguments, this function receives:
- `yvar_name` (vector of `Symbol`), which contains the names of the dependent variables to be plotted,
- `plot_type`, which can take `:density`, `:cumulative`, and `:histogram` as values,
- `predictive_check`, which is used for the plot titles and can be `:prior` or `:posterior` (default value is `:posterior`),
- `n_samples`, which sets the number of samples to be plotted (default value is 50; when plotting it is redefined as the minimum of `n_samples` and the sample size in `ypred_data`).
For more than one dependent variable in a single model, `yvar_name` must be provided, and the variable names must appear in the same order as the columns of the observed data matrix. This was done in order to separate the predictions for every dependent variable, because `predict` does not return predictions ordered by variable.
The following is a working example for a model with one dependent variable
And for posterior predictive check
Plot_type = :density

Plot_type = :cumulative
Plot_type = :histogram

Additionally, this is a working example for a model with two dependent variables