From 7781222dadbf2d618039ba0b2763f6561357f11c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Sep 2021 23:06:28 +0200 Subject: [PATCH 1/3] Store only necessary objects in fields --- src/ImportanceSampling.jl | 41 ++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index d1a4e38..2191e4e 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -43,14 +43,41 @@ struct Psis{ } weights::AT pareto_k::VT - ess::VT - sup_ess::VT r_eff::VT tail_len::Vector{Int} - posterior_sample_size::Int - data_size::Int end +function Base.propertynames(psis_object::Psis) + return ( + fieldnames(typeof(psis_object))..., + :log_weights, + :ess, + :sup_ess, + :posterior_sample_size, + :data_size, + ) +end + +function Base.getproperty(psis_object::Psis, k::Symbol) + if k === :log_weights + return log.(getfield(psis_object, :weights)) + elseif k === :posterior_sample_size + weights = getfield(psis_object, :weights) + return size(weights, 2) * size(weights, 3) + elseif k === :data_size + return size(getfield(psis_object, :weights), 1) + elseif k === :ess + weights = getfield(psis_object, :weights) + r_eff = getfield(psis_object, :r_eff) + return psis_ess(reshape(weights, size(weights, 1), :), r_eff) + elseif k === :sup_ess + weights = getfield(psis_object, :weights) + r_eff = getfield(psis_object, :r_eff) + return sup_ess(reshape(weights, size(weights, 1), :), r_eff) + else + return getfield(psis_object, k) + end +end function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis) table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess) @@ -136,14 +163,10 @@ function psis( weights = reshape(weights, dims) return Psis( - weights, + weights, ξ, - ess, - inf_ess, r_eff, tail_length, - post_sample_size, - data_size ) end From a6a468fdd6a12efc09f9b43ab53a2d01eb8f433a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Sep 2021 23:06:56 +0200 Subject: [PATCH 2/3] Avoid changing types when aliasing --- src/ImportanceSampling.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 2191e4e..0f02c35 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -139,28 +139,25 @@ function psis( post_sample_size = dims[2] * dims[3] # Reshape to matrix (easier to deal with) - log_ratios = reshape(log_ratios, data_size, post_sample_size) - r_eff = _generate_r_eff(log_ratios, dims, r_eff, source) + log_ratios_mat = reshape(log_ratios, data_size, post_sample_size) + r_eff = _generate_r_eff(log_ratios_mat, dims, r_eff, source) weights = similar(log_ratios) + weights_mat = reshape(weights, data_size, post_sample_size) # Shift ratios by maximum to prevent overflow - @. weights = exp(log_ratios - $maximum(log_ratios; dims=2)) + weights .= exp.(log_ratios .- maximum(log_ratios; dims=(2, 3))) - r_eff = _generate_r_eff(weights, dims, r_eff, source) - _check_input_validity_psis(reshape(log_ratios, dims), r_eff) + r_eff = _generate_r_eff(weights_mat, dims, r_eff, source) + _check_input_validity_psis(log_ratios, r_eff) tail_length = Vector{Int}(undef, data_size) ξ = similar(r_eff) @inbounds Threads.@threads for i in eachindex(tail_length) tail_length[i] = _def_tail_length(post_sample_size, r_eff[i]) - ξ[i] = @views psis!(weights[i, :], tail_length[i]) + ξ[i] = @views psis!(weights_mat[i, :], tail_length[i]) end - @tullio norm_const[i] := weights[i, j] + @tullio norm_const[i] := weights[i, j, k] @. weights = weights / norm_const - ess = psis_ess(weights, r_eff) - inf_ess = sup_ess(weights, r_eff) - - weights = reshape(weights, dims) return Psis( weights, From ba8678d331d99fcb8c9d38e9844a2a99e9fffc0a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Sep 2021 23:07:13 +0200 Subject: [PATCH 3/3] Drop constraint to undefined type --- src/ImportanceSampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 0f02c35..295f073 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -229,7 +229,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) # unsort the ratios to their original position: invpermute!(is_ratios, last.(ratio_index)) - return ξ::T + return ξ end