@@ -9,7 +9,7 @@ double check it is correct.
9
9
const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers
10
10
const SAMPLE_SOURCES = [" mcmc" , " vi" , " other" ]
11
11
12
- export psis, psis!, PsisLoo, PsisLooMethod, Psis
12
+ export psis, psis!, Psis
13
13
14
14
15
15
# ##########################
@@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.
24
24
25
25
# Fields
26
26
27
- - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28
- weights.
29
- - `weights`: A lazy
27
+ - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
30
28
- `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
31
29
- `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
32
30
the weights.
@@ -39,21 +37,38 @@ A struct containing the results of Pareto-smoothed importance sampling.
39
37
- `data_size`: How many data points were used for PSIS.
40
38
"""
41
39
struct Psis{
42
- RealType <: Real ,
43
- AT <: AbstractArray{RealType , 3} ,
44
- VT <: AbstractVector{RealType} ,
40
+ R <: Real ,
41
+ AT <: AbstractArray{R , 3} ,
42
+ VT <: AbstractVector{R}
45
43
}
46
44
weights:: AT
47
45
pareto_k:: VT
48
46
ess:: VT
49
47
sup_ess:: VT
50
48
r_eff:: VT
51
- tail_len:: Vector {Int}
49
+ tail_len:: AbstractVector {Int}
52
50
posterior_sample_size:: Int
53
51
data_size:: Int
54
52
end
55
53
56
54
55
+ function Base. getproperty (psis_obj:: Psis , k:: Symbol )
56
+ if k === :log_weights
57
+ return log .(getfield (psis_obj, :weights ))
58
+ else
59
+ return getfield (psis_obj, k)
60
+ end
61
+ end
62
+
63
+
64
+ function Base. propertynames (psis_object:: Psis )
65
+ return (
66
+ fieldnames (typeof (psis_object))... ,
67
+ :log_weights ,
68
+ )
69
+ end
70
+
71
+
57
72
function Base. show (io:: IO , :: MIME"text/plain" , psis_object:: Psis )
58
73
table = hcat (psis_object. pareto_k, psis_object. ess, psis_object. sup_ess)
59
74
post_samples = psis_object. posterior_sample_size
79
94
"""
80
95
psis(
81
96
log_ratios::AbstractArray{T<:Real},
82
- r_eff::AbstractVector;
97
+ r_eff::AbstractVector{T} ;
83
98
source::String="mcmc"
84
99
) -> Psis
85
100
86
101
Implements Pareto-smoothed importance sampling (PSIS).
87
102
88
103
# Arguments
104
+
89
105
## Positional Arguments
106
+
90
107
- `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the
91
108
log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left
92
109
off if there is only one chain, or if keyword argument `chain_index` is provided.
@@ -98,15 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
98
115
- `source::String="mcmc"`: A string or symbol describing the source of the sample being
99
116
used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
100
117
independent. Currently permitted values are $SAMPLE_SOURCES .
118
+ - `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to
119
+ access ESS diagnostics will return an empty list.
101
120
102
121
See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
103
122
"""
104
123
function psis (
105
- log_ratios:: AbstractArray{<:Real , 3} ;
106
- r_eff:: AbstractVector{<:Real } = similar (log_ratios, 0 ),
124
+ log_ratios:: AbstractArray{T , 3} ;
125
+ r_eff:: AbstractVector{T } = similar (log_ratios, 0 ),
107
126
source:: Union{AbstractString, Symbol} = " mcmc" ,
108
- log_weights :: Bool = true
109
- )
127
+ calc_ess :: Bool = true
128
+ ) where T <: Real
110
129
111
130
source = lowercase (String (source))
112
131
dims = size (log_ratios)
@@ -115,27 +134,35 @@ function psis(
115
134
post_sample_size = dims[2 ] * dims[3 ]
116
135
117
136
# Reshape to matrix (easier to deal with)
118
- log_ratios = reshape (log_ratios, data_size, post_sample_size)
119
- r_eff = _generate_r_eff (log_ratios, dims, r_eff, source)
120
- _check_input_validity_psis (reshape (log_ratios, dims), r_eff)
121
- weights = @. exp (log_ratios - $ maximum (log_ratios; dims= 2 ))
137
+ log_ratios_mat = reshape (log_ratios, data_size, post_sample_size)
138
+ r_eff = _generate_r_eff (log_ratios_mat, dims, r_eff, source)
139
+ _check_input_validity_psis (log_ratios, r_eff)
140
+ weights = similar (log_ratios)
141
+ weights_mat = reshape (weights, data_size, post_sample_size)
142
+ @. weights = exp (log_ratios - $ maximum (log_ratios; dims= (2 ,3 )))
143
+
122
144
123
- tail_length = Vector {Int} (undef, data_size )
145
+ tail_length = similar (r_eff, Int )
124
146
ξ = similar (r_eff)
125
147
@inbounds Threads. @threads for i in eachindex (tail_length)
126
148
tail_length[i] = _def_tail_length (post_sample_size, r_eff[i])
127
- ξ[i] = @views psis! (weights [i, :], tail_length[i])
149
+ ξ[i] = @views psis! (weights_mat [i, :], r_eff[i]; tail_length = tail_length[i])
128
150
end
129
151
130
- @tullio norm_const[i] := weights[i, j]
152
+ @tullio norm_const[i] := weights[i, j, k ]
131
153
@. weights = weights / norm_const
132
- ess = psis_ess (weights, r_eff)
133
- inf_ess = sup_ess (weights, r_eff)
134
154
135
- weights = reshape (weights, dims)
155
+
156
+ if calc_ess
157
+ ess = psis_ess (weights_mat, r_eff)
158
+ inf_ess = sup_ess (weights_mat, r_eff)
159
+ else
160
+ ess = similar (weights_mat, 0 )
161
+ inf_ess = similar (weights_mat, 0 )
162
+ end
136
163
137
164
return Psis (
138
- weights,
165
+ weights,
139
166
ξ,
140
167
ess,
141
168
inf_ess,
@@ -193,10 +220,11 @@ log-weights.
193
220
Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
194
221
valid.
195
222
"""
196
- function psis! (is_ratios:: AbstractVector{<:Real} , tail_length:: Integer ;
223
+ function psis! (is_ratios:: AbstractVector{T} , r_eff:: T = one (T);
224
+ tail_length:: Integer = _def_tail_length (length (is_ratios), r_eff),
197
225
log_weights:: Bool = false
198
- )
199
-
226
+ ) where T <: Real
227
+
200
228
len = length (is_ratios)
201
229
tail_start = len - tail_length + 1 # index of smallest tail value
202
230
@@ -213,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
213
241
214
242
# Get value just before the tail starts:
215
243
cutoff = is_ratios[tail_start - 1 ]
216
- ξ = _psis_smooth_tail! (tail, cutoff)
244
+ ξ = _psis_smooth_tail! (tail, cutoff, r_eff )
217
245
218
246
# truncate at max of raw weights (1 after scaling)
219
247
clamp! (is_ratios, 0 , 1 )
@@ -228,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
228
256
end
229
257
230
258
231
- function psis! (is_ratios:: AbstractVector{<:Real} , r_eff:: Real = 1 )
232
- tail_length = _def_tail_length (length (is_ratios), r_eff)
233
- return psis! (is_ratios, tail_length)
234
- end
235
-
236
-
237
259
"""
238
260
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
239
261
240
262
Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
241
263
must a multiple of `32*bit_length` (which improves performance).
242
264
"""
243
- function _def_tail_length (length:: Integer , r_eff:: Real = 1 )
265
+ function _def_tail_length (length:: Integer , r_eff:: Real = one (T) )
244
266
return min (cld (length, 5 ), ceil (3 * sqrt (length / r_eff))) |> Int
245
267
end
246
268
247
269
248
270
"""
249
- _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T
271
+ _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real}
272
+ -> ξ::T
250
273
251
274
Takes an *already sorted* vector of observations from the tail and smooths it *in place*
252
275
with PSIS before returning shape parameter `ξ`.
253
276
"""
254
- function _psis_smooth_tail! (tail:: AbstractVector{T} , cutoff:: T ) where {T <: Real }
277
+ function _psis_smooth_tail! (tail:: AbstractVector{T} , cutoff:: T , r_eff :: T = one (T) ) where {T <: Real }
255
278
len = length (tail)
256
279
if any (isinf .(tail))
257
280
return ξ = Inf
258
281
else
259
282
@. tail = tail - cutoff
260
283
261
284
# save time not sorting since tail is already sorted
262
- ξ, σ = gpdfit (tail)
285
+ ξ, σ = gpd_fit (tail, r_eff )
263
286
@. tail = gpd_quantile (($ (1 : len) - 0.5 ) / len, ξ, σ) + cutoff
264
287
end
265
288
return ξ
0 commit comments