-
Notifications
You must be signed in to change notification settings - Fork 36
Use NoCache
to improve set_to_zero!!
performance with Mooncake
#975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sunxd3
wants to merge
8
commits into
main
Choose a base branch
from
mooncake-nocache-optimization
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+427
−3
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
e93458c
use `NoCache` to improve `set_to_zero!!` performance with Mooncake
sunxd3 18f4c73
use concrete version number for history note
sunxd3 5c79686
fix test errors
sunxd3 7b99643
Merge branch 'main' into mooncake-nocache-optimization
sunxd3 66f453c
resolve CI error
sunxd3 aabc844
refactor
sunxd3 92f935d
refactor more; add additional test
sunxd3 de57edd
remove Mooncake from test project
sunxd3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,184 @@ | ||
module DynamicPPLMooncakeExt | ||
|
||
__precompile__(false) | ||
|
||
using DynamicPPL: DynamicPPL, istrans | ||
using Mooncake: Mooncake | ||
import Mooncake: set_to_zero!! | ||
using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!! | ||
|
||
# This is purely an optimisation. | ||
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} | ||
|
||
# ======================= | ||
# `Mooncake.set_to_zero!!` optimization with `NoCache` | ||
# ======================= | ||
|
||
""" | ||
Check if a tangent has the expected structure for a given type. | ||
""" | ||
function has_expected_structure( | ||
x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields | ||
) | ||
x isa expected_type || return false | ||
hasfield(typeof(x), :fields) || return false | ||
|
||
fields = x.fields | ||
if expected_fields isa Tuple | ||
# Exact match required | ||
propertynames(fields) == expected_fields || return false | ||
else | ||
# All expected fields must be present | ||
all(f in propertynames(fields) for f in expected_fields) || return false | ||
end | ||
|
||
return true | ||
end | ||
|
||
function is_dppl_ldf_tangent(x) | ||
has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) || | ||
return false | ||
|
||
fields = x.fields | ||
is_dppl_varinfo_tangent(fields.varinfo) || return false | ||
is_dppl_model_tangent(fields.model) || return false | ||
|
||
return true | ||
end | ||
|
||
function is_dppl_varinfo_tangent(x) | ||
return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce)) | ||
end | ||
|
||
function is_dppl_model_tangent(x) | ||
return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context)) | ||
end | ||
|
||
function is_dppl_metadata_tangent(x) | ||
# Metadata can be either: | ||
# 1. A MutableTangent with the expected fields (for single metadata) | ||
# 2. A NamedTuple where each value is a Tangent with the expected fields | ||
|
||
# Check for MutableTangent case | ||
if has_expected_structure( | ||
x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) | ||
) | ||
return true | ||
end | ||
|
||
# Check for NamedTuple case (multiple metadata) | ||
if x isa NamedTuple | ||
# Each value should be a Tangent with metadata fields | ||
for var_metadata in values(x) | ||
if !has_expected_structure( | ||
var_metadata, | ||
Tangent, | ||
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags), | ||
) | ||
return false | ||
end | ||
end | ||
return true | ||
end | ||
|
||
return false | ||
end | ||
|
||
""" | ||
has_circular_reference_risk(x) | ||
|
||
Main entry point for detecting circular reference patterns that require caching. | ||
""" | ||
function has_circular_reference_risk(x) | ||
if is_dppl_ldf_tangent(x) | ||
# Check model function for closure patterns with circular refs | ||
model_f = x.fields.model.fields.f | ||
return is_closure_with_circular_refs(model_f) | ||
elseif is_dppl_varinfo_tangent(x) | ||
return check_for_ref_fields(x) | ||
end | ||
|
||
# For unknown types, do a shallow check for PossiblyUninitTangent{Any} | ||
return x isa Mooncake.PossiblyUninitTangent{Any} | ||
end | ||
|
||
function is_closure_with_circular_refs(x) | ||
# Check if MutableTangent contains PossiblyUninitTangent{Any} | ||
if x isa MutableTangent && hasfield(typeof(x), :fields) | ||
hasfield(typeof(x.fields), :contents) && | ||
x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && | ||
return true | ||
end | ||
|
||
# For Tangent, only check immediate fields (no deep recursion) | ||
if x isa Tangent && hasfield(typeof(x), :fields) | ||
for (_, fval) in pairs(x.fields) | ||
if fval isa MutableTangent && | ||
hasfield(typeof(fval), :fields) && | ||
hasfield(typeof(fval.fields), :contents) && | ||
fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any} | ||
return true | ||
end | ||
end | ||
end | ||
|
||
return false | ||
end | ||
|
||
function check_for_ref_fields(x) | ||
# Check if it's a VarInfo tangent | ||
is_dppl_varinfo_tangent(x) || return false | ||
|
||
# Check if the logp field contains a Ref-like tangent structure | ||
hasfield(typeof(x.fields), :logp) || return false | ||
logp_tangent = x.fields.logp | ||
|
||
# Ref types in tangents often appear as MutableTangent with circular references | ||
return logp_tangent isa MutableTangent | ||
end | ||
|
||
function is_safe_dppl_type(x) | ||
# Metadata is always safe | ||
is_dppl_metadata_tangent(x) && return true | ||
|
||
# Model tangents without closures are safe | ||
if is_dppl_model_tangent(x) | ||
!is_closure_with_circular_refs(x.fields.f) && return true | ||
end | ||
|
||
# VarInfo without Ref fields is safe | ||
if is_dppl_varinfo_tangent(x) | ||
!check_for_ref_fields(x) && return true | ||
end | ||
|
||
return false | ||
end | ||
|
||
""" | ||
determine_cache_strategy(x) | ||
|
||
Determines the appropriate caching strategy for a given tangent. | ||
Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. | ||
""" | ||
function determine_cache_strategy(x) | ||
# Fast path: check for known circular reference patterns | ||
has_circular_reference_risk(x) && return IdDict{Any,Bool}() | ||
|
||
# Check for DynamicPPL types that can safely use NoCache | ||
is_safe_dppl_type(x) && return NoCache() | ||
|
||
# Special case: LogDensityFunction without problematic patterns can use NoCache | ||
if is_dppl_ldf_tangent(x) | ||
return NoCache() | ||
end | ||
|
||
# Default to safe caching for unknown types | ||
return IdDict{Any,Bool}() | ||
end | ||
|
||
function Mooncake.set_to_zero!!(x) | ||
cache = determine_cache_strategy(x) | ||
return set_to_zero_internal!!(cache, x) | ||
end | ||
|
||
end # module |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to do this because I am overloading
Mooncake.set_to_zero!!
at the bottom of this file.Alternatively, I can define
set_to_zero!!
only on tangent types, but it might be trivial as these can be deeply recursive functions. So careful implementation might need to define this function for many types.Any better ideas?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the problem in defining
set_to_zero!!
for all possible types that DPPL might bring up, but the pretty heavy-handed type piracy of definingMooncake.set_to_zero!!(x)
does trouble me. For instance, might that cause a lot of invalidations, and thus a lot more Mooncake recompilation?Also, the fact that the code relies on checking field names means that if someone just happens to define a type with the same field names, the behaviour of Mooncake on those types would depend on whether DynamicPPL is loaded in the same environment. It feels unlikely to happen, but it could lead to some truly horrendous bugs to track if it did, and also just feels like we are messing with other people's code in an inconsiderate manner.
I don't really understand the context here, but it look like this is dealing with some Mooncake issues related to circular references. Any chance that some of the machinery for dealing with that (like declaring certain types as safe/unsafe) could be implemented in Mooncake itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with and share the concerns.
The issue, from my POV, is that, Mooncake need to handle potential cases of circular references, by paying little price in a slightly conservative manner. It turns out, for simple DynamicPPL models, the little price of initializing and looking up in IdDict matter and make benchmarks look bad.
Ideally, there would be systematic changes in Mooncake so that one can tell Mooncake "I promise there will not be circular ref, so no need for IdDict".
As for the invalidations, I have to admit that I don't know how bad it would be. Given that for Mooncake, Julia would precompile
set_to_zero_internal!!
(https://github.com/chalk-lab/Mooncake.jl/blob/a26b5c35c55d1e98b9e8c6bfafbbe3dc55784140/src/tangents.jl#L728-L733). Given thatset_to_zero!!
at the moment is pretty much just an alias, I don't expect the cost is too grand.ref chalk-lab/Mooncake.jl#552 (comment)