Remove eltype, matchingvalue, get_matching_type #1015
Conversation
Benchmark Report for Commit 4a29a2a (Computer Information and Benchmark Results collapsed)
```julia
    return nothing
end

# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
```
It's because `matchingvalue` gets called on all the model function's arguments, and types can be arguments to the model as well, e.g.

```julia
@model function f(x, T) ... end
model = f(1.0, Float64)
```

```julia
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
# are happy to return `value` as-is?
```
This change was made here:
The motivation is here: TuringLang/Turing.jl#1464 (comment)
This has to do with some subtle mutation behaviour. For example

```julia
@model function f(x)
    x[1] ~ Normal()
end
```

If `model = f([1.0])`, the tilde statement is an observe, and thus even if you reassign to `x[1]` it doesn't change the value of `x`. This is the `!hasmissing` branch, and since overwriting is a no-op, we don't need to deepcopy it.

If `model = f([missing])`, the tilde statement is now an assume, and when you run the model it will sample a new value for `x[1]` and set that value in `x`. Then if you rerun the model, `x[1]` is no longer missing. This is the case where deepcopy is triggered.
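To make the aliasing concern concrete, here is a minimal sketch; `run_once!` is a hypothetical stand-in for the model's assume branch, not DynamicPPL code:

```julia
# Hypothetical sketch (not DynamicPPL internals) of why arguments containing
# `missing` need a deepcopy: the assume branch writes sampled values back into
# the argument, so without a copy the caller's array is silently mutated.
function run_once!(x::AbstractVector)
    for i in eachindex(x)
        x[i] === missing && (x[i] = randn())  # assume: sample and store
    end
    return x
end

data = Union{Missing,Float64}[missing]

run_once!(deepcopy(data))   # with a copy: `data[1]` is still `missing` afterwards
run_once!(data)             # without a copy: `data[1]` is now a Float64,
                            # so a rerun would treat it as an observe instead
```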
So apart from the deepcopy to avoid aliasing, the other place where `matchingvalue` does something meaningful is

```julia
@model function f(y, ::Type{T}=Float64) where {T}
    x = Vector{T}(undef, length(y))
    for i in eachindex(y)
        x[i] ~ Normal()
        y[i] ~ Normal(x[i])
    end
end

model = f([1.0])
```

If you just evaluate this normally with floats, it's all good. Nothing special needs to happen.
If you evaluate this with ReverseDiff, then things need to change. Specifically:
- `x` needs to become a vector of `TrackedReal`s rather than a vector of Floats.
- In order to accomplish this, the ARGUMENT to the model needs to change: even though `T` SEEMS to be specified as `Float64`, in fact `matchingvalue` hijacks it to turn it into `TrackedReal` when calling `model()`.
- How does `matchingvalue` know that it needs to become a `TrackedReal`? Simple: when you call `logdensity_and_gradient`, it calls `unflatten` to set the parameters (which will be `TrackedReal`s) in the varinfo. `matchingvalue` then looks inside the varinfo to see if the varinfo contains `TrackedReal`s! Hence `eltype(vi)` 🙃 (see the sketch after this list)
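A rough sketch of that mechanism; `pick_T` is a hypothetical stand-in, not the actual `matchingvalue` / `get_matching_type` implementation, and ForwardDiff's `Dual` stands in for ReverseDiff's `TrackedReal`:

```julia
using ForwardDiff: Dual

# Hypothetical helper mimicking the idea behind `matchingvalue` for ::Type{T}
# arguments: prefer the element type that the varinfo already holds (a dual or
# tracked number during AD) over the user-supplied default.
pick_T(vi_eltype::Type, default::Type) = vi_eltype <: Real ? vi_eltype : default

pick_T(Float64, Float64)                 # plain evaluation: T stays Float64
pick_T(typeof(Dual(1.0, 1.0)), Float64)  # AD evaluation: T becomes the dual type,
                                         # so Vector{T}(undef, n) can hold duals
```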
It actually gets a bit more complicated. When you define the model, the `@model` macro already hijacks it to turn `T` into `TypeWrap{Float64}()`, and then when you actually evaluate the model, `matchingvalue` hijacks it even further to turn it into `TypeWrap{TrackedReal}()`. Not sure why `TypeWrap` is needed, but apparently it's something to do with avoiding `DataType`.
ForwardDiff actually works just fine on this PR. I don't know why, but I also remember a talk I gave where we were surprised that ForwardDiff NUTS worked fine without the special `::Type{T}=Float64` machinery, so that is consistent with this observation.
So this whole thing pretty much only exists to make ReverseDiff happy.
To get around this, I propose that we drop compatibility with ReverseDiff.
Actually, for most models, ForwardDiff and ReverseDiff still work because of this special nice behaviour:
```julia
julia> x = Float64[1.0, 2.0]
2-element Vector{Float64}:
 1.0
 2.0

julia> x[1] = ForwardDiff.Dual(3.0) # x[1] ~ dist doesn't do this
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 0})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

julia> x = Accessors.set(x, (@optic _[1]), ForwardDiff.Dual(3.0)) # x[1] ~ dist actually does this!
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 0}}:
 Dual{Nothing}(3.0)
 Dual{Nothing}(2.0)
```

There is only one erroring test in CI, which happens because the model explicitly includes the assignment `x[i] = ...` rather than a tilde-statement `x[i] ~ ...`. Changing the assignment to use `Accessors.set` makes it work just fine.
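For concreteness, a hypothetical sketch of that kind of change; `g_broken` / `g_fixed` are illustrative models, not the actual failing test case:

```julia
using DynamicPPL, Distributions, Accessors

# The in-place assignment pins x's element type to Float64, which errors once
# AD turns `m` into a Dual/TrackedReal:
@model function g_broken(y)
    m ~ Normal()
    x = Vector{Float64}(undef, length(y))
    for i in eachindex(y)
        x[i] = m                                # setindex! cannot widen the eltype
        y[i] ~ Normal(x[i])
    end
end

# Rebuilding x via Accessors.set widens the element type, mirroring what the
# tilde pipeline does for `x[i] ~ dist`:
@model function g_fixed(y)
    m ~ Normal()
    x = Vector{Float64}(undef, length(y))
    for i in eachindex(y)
        x = Accessors.set(x, (@optic _[i]), m)  # returns a new vector with a widened eltype
        y[i] ~ Normal(x[i])
    end
end
```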
BUT there are correctness issues with ReverseDiff (not errors), and I have no clue where those stem from. And really interestingly, it's only a problem for one of the demo models, not any of the others, even though many of them use the `Type{T}` syntax.
        
          
src/model.jl (Outdated)
```diff
-        :($matchingvalue(varinfo, model.args.$var)...)
+        :(deepcopy(model.args.$var)...)
     else
-        :($matchingvalue(varinfo, model.args.$var))
+        :(deepcopy(model.args.$var))
```
So `matchingvalue` used to deepcopy things sometimes. Right now I work around this by indiscriminately deepcopying. This is a Bad Thing and we should definitely have more careful rules about when something needs to be deepcopied. However, I don't believe that such rules need to use the whole `matching_type` machinery.
Indiscriminately deepcopying here breaks ReverseDiff. See comment below: #1015 (comment)
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main    #1015      +/-   ##
==========================================
- Coverage   82.16%   80.82%   -1.34%
==========================================
  Files          38       38
  Lines        3935     3891      -44
==========================================
- Hits         3233     3145      -88
- Misses        702      746      +44
```
Pull Request Test Coverage Report for Build 16818637635 💛 - Coveralls
DynamicPPL.jl documentation for PR #1015 is available at:
ReverseDiff correctness issue with this PR:

```julia
using DynamicPPL, Distributions, FiniteDifferences, ReverseDiff, ADTypes, LinearAlgebra, Random
using DynamicPPL.TestUtils.AD: run_ad, WithBackend

@model function inner(m, x)
    @show m
    return x ~ Normal(m[1])
end

@model function outer(x)
    # m has to be vector-valued for it to fail
    m ~ MvNormal(zeros(1), I)
    # If you use this line it works
    # x ~ Normal(m[1])
    # This line is seemingly equivalent but fails
    t ~ to_submodel(inner(m, x))
end

model = outer(1.5)

run_ad(
    model,
    AutoReverseDiff();
    test=WithBackend(AutoFiniteDifferences(fdm=central_fdm(5, 1))),
    rng=Xoshiro(468),
);
```
Ironically, removing the …
If this doesn't get merged, then I would at least like to have the lessons you learned here recorded somewhere if possible.

Yupp, definitely. This PR is basically me liveblogging as I find things out 😄 but at the very least, I'm sure we could improve those docstrings / add comments.
Half-hearted attempt. I'd like to see what breaks and why.