diff --git a/Project.toml b/Project.toml index 3956c079fd..cdf6826368 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.36.3" +DynamicPPL = "0.37" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.8.8" @@ -90,3 +90,6 @@ julia = "1.10.2" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 5718e3855a..8a34d26498 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -58,15 +58,12 @@ function DynamicPPL.initialstep( # Ensure that initial sample is in unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) end # Define log-density function. ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype ) # Perform initial step. diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index d6c253e2a2..635eb89111 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) 
@@ -57,8 +56,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +72,8 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) + return _optimize(f, args...; kwargs...) end """ @@ -104,8 +102,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) @@ -127,8 +124,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -144,9 +140,10 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) 
+ f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) + return _optimize(f, args...; kwargs...) end + """ _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...) @@ -166,7 +163,9 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype + ) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -184,7 +183,7 @@ function _optimize( vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) logdensity_optimum = Optimisation.OptimLogDensity( - f.ldf.model, vi_optimum, f.ldf.context + f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype ) vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum) varnames = map(Symbol ∘ first, vns_vals_iter) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0370e619a3..15efe2ad18 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -18,6 +18,7 @@ using DynamicPPL: push!!, setlogp!!, getlogp, + getlogjoint, VarName, getsym, getdist, @@ -26,9 +27,6 @@ using DynamicPPL: SampleFromPrior, SampleFromUniform, DefaultContext, - PriorContext, - LikelihoodContext, - SamplingContext, set_flag!, unset_flag! 
using Distributions, Libtask, Bijectors @@ -139,7 +137,7 @@ end Transition(θ, lp) = Transition(θ, lp, nothing) function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) θ = getparams(model, vi) - lp = getlogp(vi) + lp = getlogjoint(vi) return Transition(θ, lp, getstats(t)) end @@ -152,10 +150,10 @@ function metadata(t::Transition) end end -DynamicPPL.getlogp(t::Transition) = t.lp +DynamicPPL.getlogjoint(t::Transition) = t.lp # Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) +metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),) ########################## # Chain making utilities # @@ -218,7 +216,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) end function get_transition_extras(ts::AbstractVector{<:VarInfo}) - valmat = reshape([getlogp(t) for t in ts], :, 1) + valmat = reshape([getlogjoint(t) for t in ts], :, 1) return [:lp], valmat end @@ -437,7 +435,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m` julia> transitions = Turing.Inference.transitions_from_chain(m, chain); -julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints +julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints 2-element Array{Float64,1}: -3.6294991938628374 -2.5697948166987845 diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index fd4d441bdd..4d55d5c698 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,7 +1,9 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. 
function _check_model(model::DynamicPPL.Model) - return DynamicPPL.check_model(model; error_on_failure=true) + # TODO(DPPL0.37/penelopeysm): use InitContext + spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context)) + return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) end function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm) return _check_model(model) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index dfd1fc0d30..6f80dea114 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -53,10 +53,14 @@ function AbstractMCMC.step( length(initial_params) == n || throw(ArgumentError("initial parameters have to be specified for each walker")) vis = map(vis, initial_params) do vi, init + # TODO(DPPL0.37/penelopeysm) This whole thing can be replaced with init!! vi = DynamicPPL.initialize_parameters!!(vi, init, model) # Update log joint probability. - last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior())) + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context) + ) + last(DynamicPPL.evaluate!!(spl_model, vi)) end end @@ -68,7 +72,7 @@ function AbstractMCMC.step( vis[1], map(vis) do vi vi = DynamicPPL.link!!(vi, model) - AMH.Transition(vi[:], getlogp(vi), false) + AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false) end, ) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 5448173486..86b92b28ee 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -24,7 +24,7 @@ struct ESS <: InferenceAlgorithm end # always accept in the first step function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... 
) for vn in keys(vi) dist = getdist(vi, vn) @@ -35,45 +35,37 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( -    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... +    rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... ) # obtain previous sample f = vi[:] # define previous sampler state # (do not use cache to avoid in-place sampling from prior) -    oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing) +    oldstate = EllipticalSliceSampling.ESSState(f, DynamicPPL.getloglikelihood(vi), nothing) # compute next state sample, state = AbstractMCMC.step( rng, -        EllipticalSliceSampling.ESSModel( -            ESSPrior(model, spl, vi), -            DynamicPPL.LogDensityFunction( -                model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()) -            ), -        ), +        EllipticalSliceSampling.ESSModel(ESSPrior(model, vi), ESSLikelihood(model, vi)), EllipticalSliceSampling.ESS(), oldstate, ) # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) -    vi = setlogp!!(vi, state.loglikelihood) +    vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end # Prior distribution of considered random variable -struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} +struct ESSPrior{M<:Model,V<:AbstractVarInfo,T} model::M -    sampler::S varinfo::V μ::T -    function ESSPrior{M,S,V}( -        model::M, sampler::S, varinfo::V -    ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} +    function ESSPrior(model::Model, varinfo::AbstractVarInfo) vns = keys(varinfo) μ = mapreduce(vcat, vns) do vn dist = getdist(varinfo, vn) @@ -81,47 +73,43 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} error("[ESS] only supports Gaussian prior distributions") DynamicPPL.tovec(mean(dist)) end -        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ) +        return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ) end end -function ESSPrior(model::Model, sampler::Sampler{<:ESS},
varinfo::AbstractVarInfo) - return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo) -end - # Ensure that the prior is a Gaussian distribution (checked in the constructor) EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true # Only define out-of-place sampling function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) - sampler = p.sampler varinfo = p.varinfo # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? + # TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model, + # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason + # why we had to use the 'del' flag before this was because + # SampleFromPrior() wouldn't overwrite existing variables. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") end - p.model(rng, varinfo, sampler) + p.model(rng, varinfo) return varinfo[:] end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -# Evaluate log-likelihood of proposals -const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} +# Evaluate log-likelihood of proposals. We need this struct because +# EllipticalSliceSampling.jl expects a callable struct / a function as its +# likelihood. 
+struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} +    ldf::DynamicPPL.LogDensityFunction{M,V} -(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f) - -function DynamicPPL.tilde_assume( -    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi -) -    return DynamicPPL.tilde_assume( -        rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi -    ) + # Force usage of `getloglikelihood` in inner constructor + function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) + return new{typeof(model),typeof(varinfo)}(ldf) + end end -function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) -end +(ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 7fa7692e4c..992a2fb2db 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -96,12 +96,12 @@ getlogp_external(::Any, ::Any) = missing getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density -struct TuringState{S,V1<:AbstractVarInfo,M,V,C} +struct TuringState{S,V1<:AbstractVarInfo,M,V} state::S # Note that this varinfo has the correct parameters and logp obtained from # the state, whereas `ldf.varinfo` will in general have junk inside it. varinfo::V1 - ldf::DynamicPPL.LogDensityFunction{M,V,C} + ldf::DynamicPPL.LogDensityFunction{M,V} end varinfo(state::TuringState) = state.varinfo @@ -126,7 +126,13 @@ function make_updated_varinfo( return if ismissing(new_logp) last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context)) else - DynamicPPL.setlogp!!(new_varinfo, new_logp) + # TODO(DPPL0.37/penelopeysm) This is obviously wrong.
Note that we + # have the same problem here as in HMC in that the sampler doesn't + # tell us about how logp is broken down into prior and likelihood. + # We should probably just re-evaluate unconditionally. A bit + # unfortunate. + DynamicPPL.setlogprior!!(new_varinfo, 0.0) + DynamicPPL.setloglikelihood!!(new_varinfo, new_logp) end end @@ -156,7 +162,9 @@ function AbstractMCMC.step( end # Construct LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, varinfo; adtype=alg.adtype + ) # Then just call `AbstractMCMC.step` with the right arguments. if initial_state === nothing diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index f36cb9c364..81281389ec 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context) # # Purpose: avoid triggering resampling of variables we're conditioning on. # - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`. # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to # undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable # rather than only for the "true" observations. @@ -178,16 +178,18 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - value, lp, _ = DynamicPPL.tilde_assume( + # TODO(mhauru) Fix accumulation here. In this branch anything that gets + # accumulated just gets discarded with `_`. 
+ value, _ = DynamicPPL.tilde_assume( child_context, right, vn, get_global_varinfo(context) ) - value, lp, vi + value, vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( child_context, DynamicPPL.SampleFromPrior(), right, @@ -195,7 +197,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end @@ -210,12 +212,12 @@ function DynamicPPL.tilde_assume( return if is_target_varname(context, vn) DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, lp, _ = DynamicPPL.tilde_assume( + value, _ = DynamicPPL.tilde_assume( child_context, right, vn, get_global_varinfo(context) ) - value, lp, vi + value, vi else - value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, DynamicPPL.SampleFromPrior(), @@ -224,7 +226,7 @@ function DynamicPPL.tilde_assume( get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end @@ -347,7 +349,7 @@ function initial_varinfo(rng, model, spl, initial_params) # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) end return vi end @@ -532,9 +534,7 @@ function setparams_varinfo!!( ) # The state is already a VarInfo, so we can just return `params`, but first we 
need to # update its logprob. - # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated - # with `model.context` before hitting `model.f`. - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) + return last(DynamicPPL.evaluate!!(model, params)) end function setparams_varinfo!!( @@ -544,10 +544,8 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. To do this, we have to call evaluate!! with the sampler, rather - # than just a context, because ESS is peculiar in how it uses LikelihoodContext for - # some variables and DefaultContext for others. - return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler))) + # update its logprob. + return last(DynamicPPL.evaluate!!(model, params)) end function setparams_varinfo!!( @@ -557,7 +555,7 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype + model, DynamicPPL.getlogjoint, state.ldf.varinfo; adtype=sampler.alg.adtype ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index b5f51587b1..e19f023437 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -191,14 +191,7 @@ function DynamicPPL.initialstep( metricT = getmetricT(spl.alg) metric = metricT(length(theta)) ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) 
- DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -213,8 +206,8 @@ function DynamicPPL.initialstep( end theta = vi[:] - # Cache current log density. - log_density_old = getlogp(vi) + # Cache current log density. We will reuse this if the transition is rejected. + logp_old = DynamicPPL.getlogp(vi) # Find good eps if not provided one if iszero(spl.alg.ϵ) @@ -239,13 +232,21 @@ function DynamicPPL.initialstep( ) end - # Update `vi` based on acceptance + # Update VarInfo based on acceptance if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # Re-evaluate to calculate log probability density. + # TODO(penelopeysm): This seems a little bit wasteful. Unfortunately, + # even though `t.stat.log_density` contains some kind of logp, this + # doesn't track prior and likelihood separately but rather a single + # log-joint (and in linked space), so which we have no way to decompose + # this back into prior and likelihood. I don't immediately see how to + # solve this without re-evaluating the model. + _, vi = DynamicPPL.evaluate!!(model, vi) else + # Reset VarInfo back to its original state. vi = DynamicPPL.unflatten(vi, theta) - vi = setlogp!!(vi, log_density_old) + vi = DynamicPPL.setlogp!!(vi, logp_old) end transition = Transition(model, vi, t) @@ -290,7 +291,9 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # Re-evaluate to calculate log probability density. + # TODO(penelopeysm): This seems a little bit wasteful. See note above. + _, vi = DynamicPPL.evaluate!!(model, vi) end # Compute next transition and state. 
@@ -303,14 +306,7 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) - DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -516,10 +512,6 @@ function DynamicPPL.assume( return DynamicPPL.assume(dist, vn, vi) end -function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi) - return DynamicPPL.observe(d, value, vi) -end - #### #### Default HMC stepsize and mass matrix adaptor #### diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index d83abd173c..5f2f1627fb 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -31,14 +31,20 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... ) - return Transition(model, vi), nothing + # Need to manually construct the Transition here because we only + # want to use the likelihood. + xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getloglikelihood(vi) + return Transition(xs, lp, nothing), nothing end function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... ) vi = VarInfo(rng, model, spl) - return Transition(model, vi), nothing + xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getloglikelihood(vi) + return Transition(xs, lp, nothing), nothing end # Calculate evidence. 
@@ -53,9 +59,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName r = rand(rng, dist) vi = push!!(vi, vn, r, dist) end - return r, 0, vi -end - -function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi) - return logpdf(dist, value), vi + vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist) + return r, vi end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index fb50c5f582..019af79391 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -157,6 +157,8 @@ end # Utility functions # ##################### +# TODO(DPPL0.37/penelopeysm): This function should no longer be needed +# once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -181,21 +183,19 @@ function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTup end end -""" - MHLogDensityFunction - -A log density function for the MH sampler. - -This variant uses the `set_namedtuple!` function to update the `VarInfo`. -""" -const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} - -function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) +# NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems +# interface in that it gets evaluated with a NamedTuple. Hence we need this +# method just to deal with MH. +# TODO(DPPL0.37/penelopeysm): Check the extent to which this method is actually +# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, +# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). +# In general, we should much prefer to either (1) conform to the +# LogDensityProblems interface or (2) use VarNames anyway. 
+function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) -    vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) +    vi_new = last(DynamicPPL.evaluate!!(f.model, vi)) -    lj = getlogp(vi_new) +    lj = f.getlogdensity(vi_new) return lj end @@ -304,7 +304,7 @@ function propose!!( # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(dt) -    prev_trans = AMH.Transition(vt, getlogp(vi), false) +    prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false) # Make a new transition. densitymodel = AMH.DensityModel( @@ -339,7 +339,7 @@ function propose!!( # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) -    prev_trans = AMH.Transition(vals, getlogp(vi), false) +    prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false) # Make a new transition. densitymodel = AMH.DensityModel( @@ -392,7 +392,3 @@ function DynamicPPL.assume( retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) return retval end - -function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi) - return DynamicPPL.observe(SampleFromPrior(), d, value, vi) -end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ffc1019519..549a4a02df 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -18,15 +18,16 @@ function TracedModel( varinfo::AbstractVarInfo, rng::Random.AbstractRNG, ) - context = SamplingContext(rng, sampler, DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) + spl_model = DynamicPPL.contextualize(model, spl_context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) if kwargs !== nothing && !isempty(kwargs) error( "Sampling with `$(sampler.alg)` does not support models with keyword arguments.
See issue #2007 for more details.", ) end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( - model, sampler, varinfo, (model.f, args...) + spl_model, sampler, varinfo, (spl_model.f, args...) ) end @@ -34,13 +35,15 @@ function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - DynamicPPL.increment_num_produce!(trace.model.f.varinfo) + trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!( + trace.model.f.varinfo + ) isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) if score === nothing return nothing else - return score + DynamicPPL.getlogp(trace.model.f.varinfo) + return score + DynamicPPL.getlogjoint(trace.model.f.varinfo) end end @@ -50,13 +53,11 @@ function AdvancedPS.delete_retained!(trace::TracedModel) end function AdvancedPS.reset_model(trace::TracedModel) - DynamicPPL.reset_num_produce!(trace.varinfo) - return trace + return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) end function AdvancedPS.reset_logprob!(trace::TracedModel) - DynamicPPL.resetlogp!!(trace.model.varinfo) - return trace + return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo) end function AdvancedPS.update_rng!( @@ -127,7 +128,7 @@ function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) # This is pretty useless since we reset the log probability continuously in the # particle sweep. - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return SMCTransition(theta, lp, weight) end @@ -193,10 +194,10 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. 
- DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) - DynamicPPL.empty!!(vi) + vi = DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.empty!!(vi) # Create a new set of particles. particles = AdvancedPS.ParticleContainer( @@ -306,7 +307,7 @@ function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) # This is pretty useless since we reset the log probability continuously in the # particle sweep. - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return PGTransition(theta, lp, logevidence) end @@ -327,9 +328,9 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo before new sweep - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.resetlogp!!(vi) # Create a new set of particles num_particles = spl.alg.nparticles @@ -359,8 +360,8 @@ function AbstractMCMC.step( ) # Reset the VarInfo before new sweep. vi = state.vi - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) + vi = DynamicPPL.resetlogp!!(vi) # Create reference particle for which the samples will be retained. reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) @@ -395,7 +396,7 @@ function AbstractMCMC.step( end function DynamicPPL.use_threadsafe_eval( - ::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo + ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo ) return false end @@ -428,6 +429,8 @@ function trace_local_rng_maybe(rng::Random.AbstractRNG) end end +# TODO(DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be +# thoroughly fixed. 
function DynamicPPL.assume( rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo ) @@ -441,33 +444,36 @@ function DynamicPPL.assume( DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent r = rand(trng, dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) else r = vi[vn] end - # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something? - lp = 0 - return r, lp, vi + # TODO: call accumulate_assume?! + return r, vi end -function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. - return logpdf(dist, value), trace_local_varinfo_maybe(vi) -end +# TODO(mhauru) Fix this. +# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) +# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. +# return logpdf(dist, value), trace_local_varinfo_maybe(vi) +# end function DynamicPPL.acclogp!!( - context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp + context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, + varinfo::AbstractVarInfo, + logp, ) varinfo_trace = trace_local_varinfo_maybe(varinfo) return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) end -function DynamicPPL.acclogp_observe!!( - context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp -) - Libtask.produce(logp) - return trace_local_varinfo_maybe(varinfo) -end +# TODO(mhauru) Fix this. 
+# function DynamicPPL.acclogp_observe!!( +# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp +# ) +# Libtask.produce(logp) +# return trace_local_varinfo_maybe(varinfo) +# end # Convenient constructor function AdvancedPS.Trace( @@ -477,7 +483,7 @@ function AdvancedPS.Trace( rng::AdvancedPS.TracedRNG, ) newvarinfo = deepcopy(varinfo) - DynamicPPL.reset_num_produce!(newvarinfo) + newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo) tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c7a5cc5737..eadeaceb38 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,14 +12,16 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - vi = last( - DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), - ), + # TODO(DPPL0.37/penelopeysm): replace with init!! + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) ) - return vi, nothing + _, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo()) + # Need to manually construct the Transition here because we only + # want to use the prior probability. 
+ xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getlogprior(vi) + return Transition(xs, lp, nothing), nothing end DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 0c322244eb..2d669cd908 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -200,7 +200,7 @@ end function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) theta = getparams(model, vi) - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return SGLDTransition(theta, lp, stepsize) end diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index ddcc27b876..058514f60f 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -43,75 +43,129 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in """ struct MAP <: ModeEstimator end +# Most of these functions for LogPriorWithoutJacobianAccumulator are copied from +# LogPriorAccumulator. The only one that is different is the accumulate_assume!! one. """ - OptimizationContext{C<:AbstractContext} <: AbstractContext + LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator -The `OptimizationContext` transforms variables to their constrained space, but -does not use the density with respect to the transformation. This context is -intended to allow an optimizer to sample in R^n freely. +Exactly like DynamicPPL.LogPriorAccumulator, but does not include the log determinant of the +Jacobian of any variable transformations. + +Used for MAP optimisation. 
""" -struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - context::C +struct LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator + logp::T +end - function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext} - if !( - context isa Union{ - DynamicPPL.DefaultContext, - DynamicPPL.LikelihoodContext, - DynamicPPL.PriorContext, - } - ) - msg = """ - `OptimizationContext` supports only leaf contexts of type - `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, - and `DynamicPPL.PriorContext` (given: `$(typeof(context)))` - """ - throw(ArgumentError(msg)) - end - return new{C}(context) - end +""" + LogPriorWithoutJacobianAccumulator{T}() + +Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero. +""" +LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = + LogPriorWithoutJacobianAccumulator(zero(T)) +function LogPriorWithoutJacobianAccumulator() + return LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}() end -OptimizationContext(ctx::DynamicPPL.AbstractContext) = OptimizationContext{typeof(ctx)}(ctx) +function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator) + return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))") +end -DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() +function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) + return :LogPriorWithoutJacobian +end -function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) - r = vi[vn, dist] - lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} - # MAP - Distributions.logpdf(dist, r) - else - # MLE - 0 - end - return r, lp, vi +function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} + return LogPriorWithoutJacobianAccumulator(zero(T)) +end + +function DynamicPPL.combine( + acc::LogPriorWithoutJacobianAccumulator, 
acc2::LogPriorWithoutJacobianAccumulator +) + return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) +end + +function Base.:+( + acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator +) + return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) +end + +function Base.zero(acc::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(zero(acc.logp)) +end + +function DynamicPPL.accumulate_assume!!( + acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right +) + return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) +end +function DynamicPPL.accumulate_observe!!( + acc::LogPriorWithoutJacobianAccumulator, right, left, vn +) + return acc +end + +function Base.convert( + ::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator +) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) +end + +function DynamicPPL.convert_eltype( + ::Type{T}, acc::LogPriorWithoutJacobianAccumulator +) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) +end + +function getlogprior_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + acc = DynamicPPL.getacc(vi, Val(:LogPriorWithoutJacobian)) + return acc.logp +end + +function getlogjoint_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + return getlogprior_without_jacobian(vi) + DynamicPPL.getloglikelihood(vi) end -function DynamicPPL.tilde_observe( - ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args... +# This is called when constructing a LogDensityFunction, and ensures the VarInfo has the +# right accumulators. +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogprior_without_jacobian) ) - return DynamicPPL.tilde_observe(ctx.context, args...) 
+ vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!(vi, (LogPriorWithoutJacobianAccumulator(),)) + return vi +end + +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogjoint_without_jacobian) +) + vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!( + vi, (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) + ) + return vi end """ OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, - AD<:ADTypes.AbstractADType + F<:Function, + V<:DynamicPPL.AbstractVarInfo, + AD<:ADTypes.AbstractADType, } A struct that wraps a single LogDensityFunction. Can be invoked either using ```julia -OptimLogDensity(model, varinfo, ctx; adtype=adtype) +OptimLogDensity(model, varinfo; adtype=adtype) ``` or ```julia -OptimLogDensity(model, ctx; adtype=adtype) +OptimLogDensity(model; adtype=adtype) ``` If not specified, `adtype` defaults to `AutoForwardDiff()`. @@ -129,37 +183,38 @@ the underlying LogDensityFunction at the point `z`. This is done to satisfy the Optim.jl interface. 
```julia -optim_ld = OptimLogDensity(model, varinfo, ctx) +optim_ld = OptimLogDensity(model, varinfo) optim_ld(z) # returns -logp ``` """ struct OptimLogDensity{ - M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, - AD<:ADTypes.AbstractADType, + M<:DynamicPPL.Model,F<:Function,V<:DynamicPPL.AbstractVarInfo,AD<:ADTypes.AbstractADType } - ldf::DynamicPPL.LogDensityFunction{M,V,C,AD} -end + ldf::DynamicPPL.LogDensityFunction{M,F,V,AD} -function OptimLogDensity( - model::DynamicPPL.Model, - vi::DynamicPPL.VarInfo, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi, ctx; adtype=adtype)) -end - -# No varinfo -function OptimLogDensity( - model::DynamicPPL.Model, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity( - DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype) + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + vi::DynamicPPL.VarInfo; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, + ) + return new{typeof(model),typeof(getlogdensity),typeof(vi),typeof(adtype)}( + DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + ) + end + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) + # No varinfo + return OptimLogDensity( + model, + getlogdensity, + DynamicPPL.ldf_default_varinfo(model, getlogdensity); + adtype=adtype, + ) + end end """ @@ -325,10 +380,13 @@ function StatsBase.informationmatrix( # Convert the values to their unconstrained states to make sure the # Hessian is computed with respect to the untransformed parameters. 
- linked = DynamicPPL.istrans(m.f.ldf.varinfo) + old_ldf = m.f.ldf + linked = DynamicPPL.istrans(old_ldf.varinfo) if linked - new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) + new_f = OptimLogDensity( + old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -339,8 +397,11 @@ function StatsBase.informationmatrix( # Link it back if we invlinked it. if linked - new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + invlinked_ldf = m.f.ldf + new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) + new_f = OptimLogDensity( + invlinked_ldf.model, old_ldf.getlogdensity, new_vi; adtype=invlinked_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -550,7 +611,12 @@ function estimate_mode( ub=nothing, kwargs..., ) - check_model && DynamicPPL.check_model(model; error_on_failure=true) + if check_model + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) + end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) initial_params = generate_initial_params(model, initial_params, constraints) @@ -560,19 +626,15 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. - inner_context = if estimator isa MAP - DynamicPPL.DefaultContext() - else - DynamicPPL.LikelihoodContext() - end - ctx = OptimizationContext(inner_context) + getlogdensity = + estimator isa MAP ? getlogjoint_without_jacobian : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? 
Any time that logp is calculated # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the # varinfo are completely ignored. The parameters only matter if you are calling evaluate!! # directly on the fields of the LogDensityFunction - vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.ldf_default_varinfo(model, getlogdensity) vi = DynamicPPL.unflatten(vi, initial_params) # Link the varinfo if needed. @@ -585,7 +647,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi, ctx) + log_density = OptimLogDensity(model, getlogdensity, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index b9428af112..d516319684 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -17,12 +17,6 @@ export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian include("deprecated.jl") -function make_logdensity(model::DynamicPPL.Model) - weight = 1.0 - ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) -end - """ q_initialize_scale( [rng::Random.AbstractRNG,] @@ -68,7 +62,7 @@ function q_initialize_scale( num_max_trials::Int=10, reduce_factor::Real=one(eltype(scale)) / 2, ) - prob = make_logdensity(model) + prob = LogDensityFunction(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) varinfo = DynamicPPL.VarInfo(model) @@ -309,7 +303,7 @@ function vi( ) return AdvancedVI.optimize( rng, - make_logdensity(model), + LogDensityFunction(model), objective, q, n_iterations; diff --git a/test/Project.toml b/test/Project.toml index 0048224d50..7e25379122 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC 
= "2.1.6, 3.0" -DynamicPPL = "0.36.12" +DynamicPPL = "0.37" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5d..f53dd98358 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -155,35 +155,33 @@ end # child context. function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi ) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) check_adtype(context, vi) - return logp, vi + return left, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!( DynamicPPL.childcontext(context), sampler, right, left, vi ) check_adtype(context, vi) - return logp, vi + return left, vi end """ @@ 
-256,8 +254,10 @@ end @testset "model=$(model.f)" for model in DEMO_MODELS rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + ) + @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any end end end @@ -283,8 +283,10 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) + ) + @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any end end end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index a0d4421869..38baa46fc2 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -113,36 +113,6 @@ using Turing check_gdemo(chn3_contd) end - @testset "Contexts" begin - # Test LikelihoodContext - @model function testmodel1(x) - a ~ Beta() - lp1 = getlogp(__varinfo__) - x[1] ~ Bernoulli(a) - return global loglike = getlogp(__varinfo__) - lp1 - end - model = testmodel1([1.0]) - varinfo = DynamicPPL.VarInfo(model) - model(varinfo, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - @test getlogp(varinfo) == loglike - - # Test MiniBatchContext - @model function testmodel2(x) - a ~ Beta() - return x[1] ~ Bernoulli(a) - end - model = testmodel2([1.0]) - varinfo1 = DynamicPPL.VarInfo(model) - varinfo2 = deepcopy(varinfo1) - model(varinfo1, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - model( - varinfo2, - DynamicPPL.SampleFromPrior(), - DynamicPPL.MiniBatchContext(DynamicPPL.LikelihoodContext(), 10), - ) - @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) - end - @testset "Prior" begin N = 10_000 @@ 
-174,21 +144,6 @@ using Turing @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end - - @testset "#2169" begin - # Not exactly the same as the issue, but similar. - @model function issue2169_model() - if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext - x ~ Normal(0, 1) - else - x ~ Normal(1000, 1) - end - end - - model = issue2169_model() - chain = sample(StableRNG(seed), model, Prior(), 10) - @test all(mean(chain[:x]) .< 5) - end end @testset "chain ordering" begin diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index e2dc417d09..6a6aebddb0 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -20,7 +20,9 @@ function initialize_nuts(model::DynamicPPL.Model) linked_vi = DynamicPPL.link!!(vi, model) # Create a LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, linked_vi; adtype=Turing.DEFAULT_ADTYPE) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, linked_vi; adtype=Turing.DEFAULT_ADTYPE + ) # Choose parameter dimensionality and initial parameter value D = LogDensityProblems.dimension(f) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 8832e5fe7b..f78c7a0237 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,6 +171,8 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? 
@testset "prior" begin # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index add2e7404a..3bbb83db5f 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -262,6 +262,8 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? @testset "prior" begin alg = MH() gdemo_default_prior = DynamicPPL.contextualize( diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 2acb7edc55..9909ee149b 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -24,28 +24,7 @@ using Turing hasstats(result) = result.optim_result.stats !== nothing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 - @testset "OptimizationContext" begin - # Used for testing how well it works with nested contexts. - struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext - context::C - logprior_weight::T1 - loglikelihood_weight::T2 - end - DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent() - DynamicPPL.childcontext(parent::OverrideContext) = parent.context - DynamicPPL.setchildcontext(parent::OverrideContext, child) = - OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight) - - # Only implement what we need for the models above. 
- function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, context.logprior_weight, vi - end - function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return context.loglikelihood_weight, vi - end - + @testset "OptimLogDensity and contexts" begin @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -62,48 +41,44 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + # Doesn't matter if we use getlogjoint or getlogjoint_without_jacobian since the + # VarInfo isn't linked. + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) - end - - @testset "Weighted" begin - function override(model) - return DynamicPPL.contextualize( - model, OverrideContext(model.context, 100, 1) - ) - end - m1 = override(model1(x)) - m2 = override(model2() | (x=x,)) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, 
DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end - @testset "Default, Likelihood, Prior Contexts" begin + @testset "Joint, prior, and likelihood" begin m1 = model1(x) - defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) a = [0.3] + ld_joint = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld_prior = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogprior_without_jacobian + ) + ld_likelihood = Turing.Optimisation.OptimLogDensity( + m1, DynamicPPL.getloglikelihood + ) + @test ld_joint(a) == ld_prior(a) + ld_likelihood(a) - @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == - Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + - Turing.Optimisation.OptimLogDensity(m1, prictx)(a) - - # test that PriorContext is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ + # test that the prior accumulator is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -651,8 +626,7 @@ using Turing return nothing end m = saddle_model() - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - optim_ld = Turing.Optimisation.OptimLogDensity(m, ctx) + optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m)