Support DPPL 0.37 #2550

Draft: wants to merge 10 commits into main

5 changes: 4 additions & 1 deletion Project.toml
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.36.3"
DynamicPPL = "0.37"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.8.8"
@@ -90,3 +90,6 @@ julia = "1.10.2"
[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
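
For local development, the same pin can be reproduced from the Julia REPL. A minimal sketch of a roughly equivalent `Pkg` call (the `[sources]` entry above is what actually takes effect for this PR):

using Pkg
# Track the unreleased DynamicPPL branch that this PR is developed against.
Pkg.add(url="https://github.com/TuringLang/DynamicPPL.jl", rev="breaking")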
9 changes: 2 additions & 7 deletions ext/TuringDynamicHMCExt.jl
@@ -58,16 +58,11 @@ function DynamicPPL.initialstep(
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi)
vi = DynamicPPL.link!!(vi, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
vi = last(DynamicPPL.evaluate!!(model, vi))
end

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)
ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)

# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
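
The new construction drops the `SamplingContext` argument entirely. A minimal sketch of the resulting call pattern, assuming the unreleased DynamicPPL branch pinned above (the toy model and data are illustrative, not part of this PR):

using DynamicPPL, Distributions, LogDensityProblems

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(0.5)
vi = DynamicPPL.VarInfo(model)
# No SamplingContext is passed; the LogDensityFunction uses the model's own context.
ldf = DynamicPPL.LogDensityFunction(model, vi)
LogDensityProblems.logdensity(ldf, vi[:])  # log density at the current parameters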
27 changes: 13 additions & 14 deletions ext/TuringOptimExt.jl
@@ -34,8 +34,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +56,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -74,8 +72,8 @@ function Optim.optimize(
end

function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
return _optimize(f, args...; kwargs...)
end

"""
@@ -104,8 +102,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +124,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -144,9 +140,10 @@ function Optim.optimize(
end

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
return _optimize(f, args...; kwargs...)
end

"""
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
@@ -166,7 +163,9 @@ function _optimize(
# whether initialisation is really necessary at all
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
vi = DynamicPPL.link(vi, f.ldf.model)
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
f = Optimisation.OptimLogDensity(
f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
)
init_vals = DynamicPPL.getparams(f.ldf)

# Optimize!
@@ -184,7 +183,7 @@ function _optimize(
vi = f.ldf.varinfo
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
logdensity_optimum = Optimisation.OptimLogDensity(
f.ldf.model, vi_optimum, f.ldf.context
f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
)
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
varnames = map(Symbol ∘ first, vns_vals_iter)
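
For reference, the user-facing entry points backed by this extension are unchanged by the refactor. A small usage sketch (the toy model and data are illustrative):

using Turing, Optim

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(0.5)
mle_estimate = optimize(model, MLE())  # maximises the log-likelihood
map_estimate = optimize(model, MAP())  # maximises the log-joint without the Jacobian term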
14 changes: 6 additions & 8 deletions src/mcmc/Inference.jl
@@ -18,6 +18,7 @@ using DynamicPPL:
push!!,
setlogp!!,
getlogp,
getlogjoint,
VarName,
getsym,
getdist,
@@ -26,9 +27,6 @@ using DynamicPPL:
SampleFromPrior,
SampleFromUniform,
DefaultContext,
PriorContext,
LikelihoodContext,
SamplingContext,
set_flag!,
unset_flag!
using Distributions, Libtask, Bijectors
@@ -139,7 +137,7 @@ end
Transition(θ, lp) = Transition(θ, lp, nothing)
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
θ = getparams(model, vi)
lp = getlogp(vi)
lp = getlogjoint(vi)
return Transition(θ, lp, getstats(t))
end

@@ -152,10 +150,10 @@ function metadata(t::Transition)
end
end

DynamicPPL.getlogp(t::Transition) = t.lp
DynamicPPL.getlogjoint(t::Transition) = t.lp

# Metadata of VarInfo object
metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),)
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)

##########################
# Chain making utilities #
@@ -218,7 +216,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
end

function get_transition_extras(ts::AbstractVector{<:VarInfo})
valmat = reshape([getlogp(t) for t in ts], :, 1)
valmat = reshape([getlogjoint(t) for t in ts], :, 1)
return [:lp], valmat
end

@@ -437,7 +435,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
2-element Array{Float64,1}:
-3.6294991938628374
-2.5697948166987845
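
The `getlogp` → `getlogjoint` switch reflects DynamicPPL 0.37 tracking prior and likelihood contributions separately rather than as a single `logp`. A minimal sketch of the relationship, assuming the 0.37 accessors (`getlogprior` is not used in this diff and is an assumption here):

using DynamicPPL, Distributions

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(0.5)
# Evaluating the model populates the separate accumulators in the varinfo.
_, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo(model))
lp = DynamicPPL.getlogjoint(vi)  # what `Transition` now stores as `lp`
lp ≈ DynamicPPL.getlogprior(vi) + DynamicPPL.getloglikelihood(vi)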
29 changes: 11 additions & 18 deletions src/mcmc/ess.jl
@@ -49,8 +49,8 @@ function AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, spl, vi),
DynamicPPL.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
ESSLogLikelihood(
DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi)
),
),
EllipticalSliceSampling.ESS(),
Expand All @@ -59,11 +59,11 @@ function AbstractMCMC.step(

# update sample and log-likelihood
vi = DynamicPPL.unflatten(vi, sample)
vi = setlogp!!(vi, state.loglikelihood)
vi = setloglikelihood!!(vi, state.loglikelihood)

return Transition(model, vi), vi
end

# Prior distribution of considered random variable
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
model::M
@@ -97,6 +97,10 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
# TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model,
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
# we had to use the 'del' flag before this was that
# SampleFromPrior() wouldn't overwrite existing variables.
vns = keys(varinfo)
for vn in vns
set_flag!(varinfo, vn, "del")
@@ -109,19 +113,8 @@ end
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}

(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)

function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
)
return DynamicPPL.tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
)
struct ESSLogLikelihood{M<:Model,V<:AbstractVarInfo,AD<:ADTypes.AbstractADType}
ldf::DynamicPPL.LogDensityFunction{M,V,AD}
end

function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)
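
The likelihood-only `LogDensityFunction` wrapped by `ESSLogLikelihood` uses the three-argument form seen above, where the second argument selects which log density to return. A sketch of that form in isolation, assuming the unreleased DynamicPPL branch pinned in Project.toml (the toy model is illustrative):

using DynamicPPL, Distributions, LogDensityProblems

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(0.5)
vi = DynamicPPL.VarInfo(model)
# Returns only the log-likelihood, which is all elliptical slice sampling needs.
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi)
LogDensityProblems.logdensity(ldf, vi[:])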
4 changes: 2 additions & 2 deletions src/mcmc/external_sampler.jl
@@ -96,12 +96,12 @@ getlogp_external(::Any, ::Any) = missing
getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density

struct TuringState{S,V1<:AbstractVarInfo,M,V,C}
struct TuringState{S,V1<:AbstractVarInfo,M,V}
state::S
# Note that this varinfo has the correct parameters and logp obtained from
# the state, whereas `ldf.varinfo` will in general have junk inside it.
varinfo::V1
ldf::DynamicPPL.LogDensityFunction{M,V,C}
ldf::DynamicPPL.LogDensityFunction{M,V}
end

varinfo(state::TuringState) = state.varinfo
14 changes: 8 additions & 6 deletions src/mcmc/gibbs.jl
@@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
#
# Purpose: avoid triggering resampling of variables we're conditioning on.
# - Using standard `DynamicPPL.condition` results in conditioned variables being treated
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
# rather than only for the "true" observations.
@@ -178,24 +178,26 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
DynamicPPL.tilde_assume(child_context, right, vn, vi)
elseif has_conditioned_gibbs(context, vn)
# Short-circuit the tilde assume if `vn` is present in `context`.
value, lp, _ = DynamicPPL.tilde_assume(
# TODO(mhauru) Fix accumulation here. In this branch anything that gets
# accumulated just gets discarded with `_`.
value, _ = DynamicPPL.tilde_assume(
child_context, right, vn, get_global_varinfo(context)
)
value, lp, vi
value, vi
else
# If the varname has not been conditioned on, nor is it a target variable, it's
# presumably a new variable that should be sampled from its prior. We need to add
# this new variable to the global `varinfo` of the context, but not to the local one
# being used by the current sampler.
value, lp, new_global_vi = DynamicPPL.tilde_assume(
value, new_global_vi = DynamicPPL.tilde_assume(
child_context,
DynamicPPL.SampleFromPrior(),
right,
vn,
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
value, lp, vi
value, vi
end
end

@@ -557,7 +559,7 @@ function setparams_varinfo!!(
params::AbstractVarInfo,
)
logdensity = DynamicPPL.LogDensityFunction(
model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype
model, state.ldf.varinfo; adtype=sampler.alg.adtype
)
new_inner_state = setparams_varinfo!!(
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
37 changes: 9 additions & 28 deletions src/mcmc/hmc.jl
@@ -190,16 +190,7 @@ function DynamicPPL.initialstep(
# Create a Hamiltonian.
metricT = getmetricT(spl.alg)
metric = metricT(length(theta))
ldf = DynamicPPL.LogDensityFunction(
model,
vi,
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
# need to pass in the sampler? (In fact LogDensityFunction defaults to
# using leafcontext(model.context) so could we just remove the argument
# entirely?)
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
adtype=spl.alg.adtype,
Comment on lines -196 to -201
As established in (e.g.) TuringLang/DynamicPPL.jl#955 (comment), SamplingContext for Hamiltonians was never overloaded, so it is equivalent to just use DefaultContext in the LDF.

)
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
@@ -214,7 +205,7 @@
theta = vi[:]

# Cache current log density.
log_density_old = getlogp(vi)
log_density_old = getloglikelihood(vi)

# Find good eps if not provided one
if iszero(spl.alg.ϵ)
@@ -242,10 +233,12 @@
# Update `vi` based on acceptance
if t.stat.is_accept
vi = DynamicPPL.unflatten(vi, t.z.θ)
vi = setlogp!!(vi, t.stat.log_density)
# TODO(mhauru) Is setloglikelihood!! the right thing here?
vi = setloglikelihood!!(vi, t.stat.log_density)
else
vi = DynamicPPL.unflatten(vi, theta)
vi = setlogp!!(vi, log_density_old)
# TODO(mhauru) Is setloglikelihood!! the right thing here?
vi = setloglikelihood!!(vi, log_density_old)
Comment on lines -245 to +241
@penelopeysm (Member), Jul 19, 2025:
Not fully sure what to do here: t.stat.log_density is a joint, and I'm not sure how to decompose it into likelihood + prior.

There is an argument that we should re-evaluate the model anyway, because t.stat.log_density provides a logpdf in linked space, whereas for user-facing purposes we want the invlinked logpdf (#2617). And I don't immediately see how to calculate this without re-evaluating the model anyway.

end

transition = Transition(model, vi, t)
@@ -290,7 +283,7 @@ function AbstractMCMC.step(
vi = state.vi
if t.stat.is_accept
vi = DynamicPPL.unflatten(vi, t.z.θ)
vi = setlogp!!(vi, t.stat.log_density)
# TODO(mhauru) Is setloglikelihood!! the right thing here?
vi = setloglikelihood!!(vi, t.stat.log_density)
end

# Compute next transition and state.
@@ -302,16 +296,7 @@ end

function get_hamiltonian(model, spl, vi, state, n)
metric = gen_metric(n, spl, state)
ldf = DynamicPPL.LogDensityFunction(
model,
vi,
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
# need to pass in the sampler? (In fact LogDensityFunction defaults to
# using leafcontext(model.context) so could we just remove the argument
# entirely?)
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
adtype=spl.alg.adtype,
)
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
return AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
@@ -516,10 +501,6 @@ function DynamicPPL.assume(
return DynamicPPL.assume(dist, vn, vi)
end

function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
return DynamicPPL.observe(d, value, vi)
end

####
#### Default HMC stepsize and mass matrix adaptor
####
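
One possible answer to the open question in the review comment above is to re-evaluate the model at the accepted position, so that the prior and likelihood components are tracked separately instead of trying to decompose `t.stat.log_density`. A sketch under that assumption (`getlogprior` and the stand-in `θ_accepted` are assumptions, not part of this PR; linking is ignored here):

using DynamicPPL, Distributions

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(0.5)
vi = DynamicPPL.VarInfo(model)
θ_accepted = [0.3]  # stand-in for t.z.θ from the AdvancedHMC transition
vi = DynamicPPL.unflatten(vi, θ_accepted)
# Re-evaluation splits the log density into its components rather than
# decomposing the joint value reported by AdvancedHMC.
_, vi = DynamicPPL.evaluate!!(model, vi)
DynamicPPL.getlogprior(vi), DynamicPPL.getloglikelihood(vi)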
4 changes: 0 additions & 4 deletions src/mcmc/is.jl
@@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
end
return r, 0, vi
end

function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
return logpdf(dist, value), vi
end