Skip to content

Commit 39b5e66

Browse files
author
Carlos Parada
authored
Lazy-compute-ESS (#52)
* add PSIS for vectors * use entropy-based ESS * Add option to avoid calculating ESS * add log_weights
1 parent 55e0066 commit 39b5e66

File tree

10 files changed

+95
-84
lines changed

10 files changed

+95
-84
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ jobs:
1717
version:
1818
- '1' # latest stable 1.x release of Julia
1919
- '1.6' # oldest supported version
20-
- 'nightly'
2120
os:
2221
- ubuntu-latest
2322
arch:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ version = "0.6.6"
66
[deps]
77
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9-
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1211
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
12+
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
1313
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
1414
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1515
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/AbstractCV.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
using AxisKeys
33
using PrettyTables
44

5-
export AbstractCVMethod, AbstractCV
5+
# export AbstractCVMethod, AbstractCV
66

7-
const POINTWISE_LABELS = (:cv_elpd, :naive_lpd, :p_eff, :ess, :pareto_k)
87
const CV_DESC = """
98
# Fields
109
@@ -73,12 +72,12 @@ An abstract type used in cross-validation.
7372
abstract type AbstractCV end
7473

7574

76-
"""
77-
AbstractCVMethod
75+
# """
76+
# AbstractCVMethod
7877

79-
An abstract type used to dispatch the correct method for cross validation.
80-
"""
81-
abstract type AbstractCVMethod end
78+
# An abstract type used to dispatch the correct method for cross validation.
79+
# """
80+
# abstract type AbstractCVMethod end
8281

8382

8483
##########################

src/ESS.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export relative_eff, psis_ess, sup_ess
55

66
"""
77
relative_eff(
8-
sample::AbstractArray{Real, 3};
8+
sample::AbstractArray{<:Real, 3};
99
method=MCMCDiagnosticTools.FFTESSMethod()
1010
)
1111
@@ -16,7 +16,7 @@ by the nominal sample size.
1616
1717
- `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values.
1818
"""
19-
function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...)
19+
function relative_eff(sample::AbstractArray{<:Real,3}; maxlag=size(sample, 2), kwargs...)
2020
dims = size(sample)
2121
post_sample_size = dims[2] * dims[3]
2222
ess_sample = permutedims(sample, [2, 1, 3])
@@ -60,13 +60,13 @@ end
6060

6161
"""
6262
function sup_ess(
63-
weights::AbstractVector{T},
63+
weights::AbstractMatrix{T},
6464
r_eff::AbstractVector{T}
6565
) -> AbstractVector
6666
6767
Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the
68-
maximum weight. This measure is more trustworthy than the `ess` from `psis_ess`. It uses the
69-
L-∞ norm.
68+
maximum weight. This measure is more sensitive than the `ess` from `psis_ess`, but also
69+
much more variable. It uses the L-∞ norm.
7070
7171
# Arguments
7272
- `weights`: A set of importance sampling weights derived from PSIS.

src/GPD.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using Tullio
55

66

77
"""
8-
gpdfit(
9-
sample::AbstractVector{T<:Real};
8+
gpd_fit(
9+
sample::AbstractVector{T<:Real},
10+
r_eff::T = 1;
1011
wip::Bool=true,
1112
min_grid_pts::Integer=30,
1213
sort_sample::Bool=false
@@ -29,12 +30,13 @@ generalized Pareto distribution (GPD), assuming the location parameter is 0.
2930
Estimation method taken from Zhang, J. and Stephens, M.A. (2009). The parameter ξ is the
3031
negative of k.
3132
"""
32-
function gpdfit(
33-
sample::AbstractVector{T};
33+
function gpd_fit(
34+
sample::AbstractVector{T},
35+
r_eff::T=1;
3436
wip::Bool=true,
3537
min_grid_pts::Integer=30,
3638
sort_sample::Bool=false,
37-
) where {T <: Real}
39+
) where T<:Real
3840

3941
len = length(sample)
4042
# sample must be sorted, but we can skip if sample is already sorted
@@ -70,7 +72,7 @@ function gpdfit(
7072

7173
# Drag towards .5 to reduce variance for small len
7274
if wip
73-
@fastmath ξ =* len + 0.5 * n_0) / (len + n_0)
75+
@fastmath ξ = (r_eff * ξ * len + 0.5 * n_0) / (r_eff * len + n_0)
7476
end
7577

7678
return ξ, σ

src/ImportanceSampling.jl

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ double check it is correct.
99
const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers
1010
const SAMPLE_SOURCES = ["mcmc", "vi", "other"]
1111

12-
export psis, psis!, PsisLoo, PsisLooMethod, Psis
12+
export psis, psis!, Psis
1313

1414

1515
###########################
@@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525
# Fields
2626
27-
- `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28-
weights.
29-
- `weights`: A lazy
27+
- `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
3028
- `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
3129
- `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3230
the weights.
@@ -39,21 +37,38 @@ A struct containing the results of Pareto-smoothed importance sampling.
3937
- `data_size`: How many data points were used for PSIS.
4038
"""
4139
struct Psis{
42-
RealType <: Real,
43-
AT <: AbstractArray{RealType, 3},
44-
VT <: AbstractVector{RealType},
40+
R <: Real,
41+
AT <: AbstractArray{R, 3},
42+
VT <: AbstractVector{R}
4543
}
4644
weights::AT
4745
pareto_k::VT
4846
ess::VT
4947
sup_ess::VT
5048
r_eff::VT
51-
tail_len::Vector{Int}
49+
tail_len::AbstractVector{Int}
5250
posterior_sample_size::Int
5351
data_size::Int
5452
end
5553

5654

55+
function Base.getproperty(psis_obj::Psis, k::Symbol)
56+
if k === :log_weights
57+
return log.(getfield(psis_obj, :weights))
58+
else
59+
return getfield(psis_obj, k)
60+
end
61+
end
62+
63+
64+
function Base.propertynames(psis_object::Psis)
65+
return (
66+
fieldnames(typeof(psis_object))...,
67+
:log_weights,
68+
)
69+
end
70+
71+
5772
function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis)
5873
table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess)
5974
post_samples = psis_object.posterior_sample_size
@@ -79,14 +94,16 @@ end
7994
"""
8095
psis(
8196
log_ratios::AbstractArray{T<:Real},
82-
r_eff::AbstractVector;
97+
r_eff::AbstractVector{T};
8398
source::String="mcmc"
8499
) -> Psis
85100
86101
Implements Pareto-smoothed importance sampling (PSIS).
87102
88103
# Arguments
104+
89105
## Positional Arguments
106+
90107
- `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the
91108
log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left
92109
off if there is only one chain, or if keyword argument `chain_index` is provided.
@@ -98,15 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
98115
- `source::String="mcmc"`: A string or symbol describing the source of the sample being
99116
used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
100117
independent. Currently permitted values are $SAMPLE_SOURCES.
118+
- `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to
119+
access ESS diagnostics will return an empty list.
101120
102121
See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
103122
"""
104123
function psis(
105-
log_ratios::AbstractArray{<:Real, 3};
106-
r_eff::AbstractVector{<:Real}=similar(log_ratios, 0),
124+
log_ratios::AbstractArray{T, 3};
125+
r_eff::AbstractVector{T}=similar(log_ratios, 0),
107126
source::Union{AbstractString, Symbol}="mcmc",
108-
log_weights::Bool=true
109-
)
127+
calc_ess::Bool = true
128+
) where T <: Real
110129

111130
source = lowercase(String(source))
112131
dims = size(log_ratios)
@@ -115,27 +134,35 @@ function psis(
115134
post_sample_size = dims[2] * dims[3]
116135

117136
# Reshape to matrix (easier to deal with)
118-
log_ratios = reshape(log_ratios, data_size, post_sample_size)
119-
r_eff = _generate_r_eff(log_ratios, dims, r_eff, source)
120-
_check_input_validity_psis(reshape(log_ratios, dims), r_eff)
121-
weights = @. exp(log_ratios - $maximum(log_ratios; dims=2))
137+
log_ratios_mat = reshape(log_ratios, data_size, post_sample_size)
138+
r_eff = _generate_r_eff(log_ratios_mat, dims, r_eff, source)
139+
_check_input_validity_psis(log_ratios, r_eff)
140+
weights = similar(log_ratios)
141+
weights_mat = reshape(weights, data_size, post_sample_size)
142+
@. weights = exp(log_ratios - $maximum(log_ratios; dims=(2,3)))
143+
122144

123-
tail_length = Vector{Int}(undef, data_size)
145+
tail_length = similar(r_eff, Int)
124146
ξ = similar(r_eff)
125147
@inbounds Threads.@threads for i in eachindex(tail_length)
126148
tail_length[i] = _def_tail_length(post_sample_size, r_eff[i])
127-
ξ[i] = @views psis!(weights[i, :], tail_length[i])
149+
ξ[i] = @views psis!(weights_mat[i, :], r_eff[i]; tail_length=tail_length[i])
128150
end
129151

130-
@tullio norm_const[i] := weights[i, j]
152+
@tullio norm_const[i] := weights[i, j, k]
131153
@. weights = weights / norm_const
132-
ess = psis_ess(weights, r_eff)
133-
inf_ess = sup_ess(weights, r_eff)
134154

135-
weights = reshape(weights, dims)
155+
156+
if calc_ess
157+
ess = psis_ess(weights_mat, r_eff)
158+
inf_ess = sup_ess(weights_mat, r_eff)
159+
else
160+
ess = similar(weights_mat, 0)
161+
inf_ess = similar(weights_mat, 0)
162+
end
136163

137164
return Psis(
138-
weights,
165+
weights,
139166
ξ,
140167
ess,
141168
inf_ess,
@@ -193,10 +220,11 @@ log-weights.
193220
Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
194221
valid.
195222
"""
196-
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
223+
function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
224+
tail_length::Integer = _def_tail_length(length(is_ratios), r_eff),
197225
log_weights::Bool=false
198-
)
199-
226+
) where T<:Real
227+
200228
len = length(is_ratios)
201229
tail_start = len - tail_length + 1 # index of smallest tail value
202230

@@ -213,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
213241

214242
# Get value just before the tail starts:
215243
cutoff = is_ratios[tail_start - 1]
216-
ξ = _psis_smooth_tail!(tail, cutoff)
244+
ξ = _psis_smooth_tail!(tail, cutoff, r_eff)
217245

218246
# truncate at max of raw weights (1 after scaling)
219247
clamp!(is_ratios, 0, 1)
@@ -228,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
228256
end
229257

230258

231-
function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1)
232-
tail_length = _def_tail_length(length(is_ratios), r_eff)
233-
return psis!(is_ratios, tail_length)
234-
end
235-
236-
237259
"""
238260
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
239261
240262
Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
241263
must a multiple of `32*bit_length` (which improves performance).
242264
"""
243-
function _def_tail_length(length::Integer, r_eff::Real=1)
265+
function _def_tail_length(length::Integer, r_eff::Real=one(T))
244266
return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int
245267
end
246268

247269

248270
"""
249-
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T
271+
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real}
272+
-> ξ::T
250273
251274
Takes an *already sorted* vector of observations from the tail and smooths it *in place*
252275
with PSIS before returning shape parameter `ξ`.
253276
"""
254-
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real}
277+
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T <: Real}
255278
len = length(tail)
256279
if any(isinf.(tail))
257280
return ξ = Inf
258281
else
259282
@. tail = tail - cutoff
260283

261284
# save time not sorting since tail is already sorted
262-
ξ, σ = gpdfit(tail)
285+
ξ, σ = gpd_fit(tail, r_eff)
263286
@. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
264287
end
265288
return ξ

src/LeaveOneOut.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,27 @@ using Statistics
55
using Printf
66
using Tullio
77

8-
export loo, psis_loo, loo_from_psis
8+
export loo, psis_loo, loo_from_psis, PsisLoo
99

1010

1111
#####################
1212
###### STRUCTS ######
1313
#####################
1414

1515

16-
"""
17-
PsisLooMethod
16+
# """
17+
# PsisLooMethod
1818

19-
Use Pareto-smoothed importance sampling together with leave-one-out cross validation to
20-
estimate the out-of-sample predictive accuracy.
21-
"""
22-
struct PsisLooMethod <: AbstractCVMethod end
19+
# Use Pareto-smoothed importance sampling together with leave-one-out cross validation to
20+
# estimate the out-of-sample predictive accuracy.
21+
# """
22+
# struct PsisLooMethod <: AbstractCVMethod end
2323

2424

2525
"""
2626
PsisLoo <: AbstractCV
2727
28-
A struct containing the results of leave-one-out cross validation using Pareto
28+
A struct containing the results of leave-one-out cross validation computed with Pareto
2929
smoothed importance sampling.
3030
3131
$CV_DESC
@@ -71,17 +71,17 @@ end
7171

7272

7373
"""
74-
function loo(args...; method=PsisLooMethod(), kwargs...) -> PsisLoo
74+
function loo(args...; kwargs...) -> PsisLoo
7575
76-
Compute the approximate leave-one-out cross-validation score using the specified method.
76+
Compute an approximate leave-one-out cross-validation score.
7777
7878
Currently, this function only serves to call `psis_loo`, but this could change in the
79-
future. The default methods or return type may change without warning; thus, we recommend
79+
future. The default methods or return type may change without warning, so we recommend
8080
using `psis_loo` instead if reproducibility is required.
8181
8282
See also: [`psis_loo`](@ref), [`PsisLoo`](@ref).
8383
"""
84-
function loo(args...; method=PsisLooMethod(), kwargs...)
84+
function loo(args...; kwargs...)
8585
return psis_loo(args...; kwargs...)
8686
end
8787

0 commit comments

Comments
 (0)