- 
                Notifications
    
You must be signed in to change notification settings  - Fork 37
 
          InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values
          #984
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
485a525
              7a05ec5
              b00e284
              5ed975c
              2706239
              84e5e55
              7f188b9
              f7ac1b1
              2041927
              d9292ad
              70bb2c4
              bc04355
              2cfc297
              891b4b3
              3bb7ade
              1bdb76e
              39b958d
              907d24f
              07946a7
              afdb173
              c641923
              956ed54
              8d13c30
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -28,7 +28,7 @@ end | |
| 
     | 
||
| function _check_varname_indexing(c::MCMCChains.Chains) | ||
| return DynamicPPL.supports_varname_indexing(c) || | ||
| error("Chains do not support indexing using `VarName`s.") | ||
| error("This `Chains` object does not support indexing using `VarName`s.") | ||
| end | ||
| 
     | 
||
| function DynamicPPL.getindex_varname( | ||
| 
        
          
        
         | 
    @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) | |
| return keys(c.info.varname_to_symbol) | ||
| end | ||
| 
     | 
||
| function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) | ||
                
      
                  mhauru marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| _check_varname_indexing(c) | ||
| d = Dict{DynamicPPL.VarName,Any}() | ||
                
      
                  penelopeysm marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| for vn in DynamicPPL.varnames(c) | ||
| d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) | ||
| end | ||
| return d | ||
| end | ||
                
       | 
||
| 
     | 
||
| """ | ||
| predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -114,9 +123,15 @@ function DynamicPPL.predict( | |
| 
     | 
||
| iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
| predictive_samples = map(iters) do (sample_idx, chain_idx) | ||
| DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) | ||
| varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) | ||
| 
     | 
||
| # Extract values from the chain | ||
| values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) | ||
| # Resample any variables that are not present in `values_dict` | ||
| _, varinfo = DynamicPPL.init!!( | ||
| rng, | ||
| model, | ||
| varinfo, | ||
| DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
| ) | ||
| vals = DynamicPPL.values_as_in_model(model, false, varinfo) | ||
| varname_vals = mapreduce( | ||
| collect, | ||
| 
          
            
          
           | 
    @@ -248,13 +263,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha | |
| varinfo = DynamicPPL.VarInfo(model) | ||
| iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
| return map(iters) do (sample_idx, chain_idx) | ||
| # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. | ||
| # Update the varinfo with the current sample and make variables not present in `chain` | ||
| # to be sampled. | ||
| DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) | ||
| # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to | ||
| # `deepcopy` the `varinfo` before passing it to the `model`. | ||
| model(deepcopy(varinfo)) | ||
| # Extract values from the chain | ||
| values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) | ||
| # Resample any variables that are not present in `values_dict`, and | ||
| # return the model's retval. | ||
| retval, _ = DynamicPPL.init!!( | ||
| model, | ||
| varinfo, | ||
| DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
| ) | ||
| retval | ||
| end | ||
| end | ||
| 
     | 
||
| 
          
            
          
           | 
    ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -850,7 +850,7 @@ end | |
| # ^ Weird Documenter.jl bug means that we have to write the two above separately | ||
| # as it can only detect the `function`-less syntax. | ||
| function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) | ||
| return first(evaluate_and_sample!!(rng, model, varinfo)) | ||
| return first(init!!(rng, model, varinfo)) | ||
| end | ||
| 
     | 
||
| """ | ||
| 
        
          
        
         | 
    @@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) | |
| return Threads.nthreads() > 1 | ||
| end | ||
| 
     | 
||
| """ | ||
| evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) | ||
| 
     | 
||
| Evaluate the `model` with the given `varinfo`, but perform sampling during the | ||
| evaluation using the given `sampler` by wrapping the model's context in a | ||
| `SamplingContext`. | ||
| 
     | 
||
| If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). | ||
| 
     | 
||
| Returns a tuple of the model's return value, plus the updated `varinfo` object. | ||
| """ | ||
| function evaluate_and_sample!!( | ||
| rng::Random.AbstractRNG, | ||
| model::Model, | ||
| varinfo::AbstractVarInfo, | ||
| sampler::AbstractSampler=SampleFromPrior(), | ||
| ) | ||
| sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) | ||
| return evaluate!!(sampling_model, varinfo) | ||
| end | ||
| function evaluate_and_sample!!( | ||
| model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() | ||
| ) | ||
| return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) | ||
| end | ||
| 
     | 
||
| """ | ||
| init!!( | ||
| [rng::Random.AbstractRNG,] | ||
| 
        
          
        
         | 
    @@ -897,10 +871,10 @@ end | |
| [init_strategy::AbstractInitStrategy=InitFromPrior()] | ||
| ) | ||
| 
     | 
||
| Evaluate the `model` and replace the values of the model's random variables in | ||
| the given `varinfo` with new values using a specified initialisation strategy. | ||
| If the values in `varinfo` are not already present, they will be added using | ||
| that same strategy. | ||
| Evaluate the `model` and replace the values of the model's random variables | ||
| in the given `varinfo` with new values, using a specified initialisation strategy. | ||
| If the values in `varinfo` are not set, they will be added. | ||
                
      
                  penelopeysm marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| using a specified initialisation strategy. | ||
| 
     | 
||
| If `init_strategy` is not provided, defaults to InitFromPrior(). | ||
                
      
                  penelopeysm marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| 
          
            
          
           | 
    @@ -1051,11 +1025,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) | |
| Generate a sample of type `T` from the prior distribution of the `model`. | ||
| """ | ||
| function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} | ||
| x = last( | ||
| evaluate_and_sample!!( | ||
| rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) | ||
| ), | ||
| ) | ||
| x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) | ||
| return values_as(x, T) | ||
| end | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -1227,25 +1197,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC | |
| end | ||
| end | ||
| 
     | 
||
| """ | ||
| predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) | ||
| 
     | 
||
| Generate samples from the posterior predictive distribution by evaluating `model` at each set | ||
| of parameter values provided in `chain`. The number of posterior predictive samples matches | ||
| the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values | ||
| and the predicted values. | ||
| """ | ||
| function predict( | ||
| rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} | ||
| ) | ||
| varinfo = DynamicPPL.VarInfo(model) | ||
| return map(chain) do params_varinfo | ||
| vi = deepcopy(varinfo) | ||
| DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) | ||
| model(rng, vi) | ||
| return vi | ||
| end | ||
| end | ||
| # Implemented & documented in DynamicPPLMCMCChainsExt | ||
| function predict end | ||
| 
         
      Comment on lines
    
      -1230
     to 
      +1201
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was discussed at one of the meetings and we decided we didn't care enough about the   | 
||
| 
     | 
||
| """ | ||
| returned(model::Model, parameters::NamedTuple) | ||
| 
          
            
          
           | 
    ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused by the comments in this function because as far as I can tell it only ever tested sampling, not both sampling and evaluation. (That was also true going further back e.g. in v0.36)
This PR thus also changes the implementation of this function to test both evaluation and sampling (i.e. initialisation) and if either fails, it will return the untyped varinfo.
Sorry I had to make this change in this PR. There were a few unholy tests where one would end up evaluating a model with a
SamplingContext{<:InitContext}, which would error unless I introduced special code to handle it, and I didn't really want to do that. JETExt was one of those unholy scenarios.