From 0aad616b7773cafb9ea18e2aed2f70ac8ba3b8da Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Jan 2023 19:40:06 +0000 Subject: [PATCH 01/17] introduction of specialize_make_closure macro --- src/DistributionsAD.jl | 40 +-------------- src/common.jl | 109 +++++++++++++++++++++++++++++++++++++++++ src/lazyarrays.jl | 48 ++++++++++++++++++ 3 files changed, 158 insertions(+), 39 deletions(-) create mode 100644 src/lazyarrays.jl diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 9be245a..4770ea0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -80,45 +80,7 @@ include("zygote.jl") end @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin - using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray - - const LazyVectorOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastVector{T}, - } = VectorOfUnivariate{S,T,Tdists} - - function Distributions._logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractVector{<:Real}, - ) - return sum(copy(logpdf.(dist.v, x))) - end - - function Distributions.logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractMatrix{<:Real}, - ) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return vec(sum(copy(logpdf.(dists, x)), dims = 1)) - end - - const LazyMatrixOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastArray{T,2}, - } = MatrixOfUnivariate{S,T,Tdists} - - function Distributions._logpdf( - dist::LazyMatrixOfUnivariate, - x::AbstractMatrix{<:Real}, - ) - return sum(copy(logpdf.(dist.dists, x))) - end - - lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...)) - export lazyarray + include("lazyarrays.jl") end end diff --git a/src/common.jl b/src/common.jl index ee094ba..15dda4f 100644 --- a/src/common.jl +++ b/src/common.jl @@ -48,3 +48,112 @@ parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @non_differentiable adapt_randn(::Any...) + +""" + make_closure(f, g) + +Return a closure of the form `(x, args...) -> f(g(args...), x)`. + +# Examples + +This is particularly useful when one wants to avoid broadcasting over constructors +which can sometimes cause issues with type-inference, in particular when combined +with reverse-mode AD frameworks. + +```juliarepl +julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools + +julia> const data = randn(1000); + +julia> x = randn(length(data)); + +julia> f(x) = sum(logpdf.(Normal.(x), data)) +f (generic function with 2 methods) + +julia> @btime ReverseDiff.gradient(\$f, \$x); + 848.759 μs (14605 allocations: 521.84 KiB) + +julia> # Much faster with ReverseDiff.jl. + g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal) + sum(g_inner.(data, x)) + end +g (generic function with 1 method) + +julia> @btime ReverseDiff.gradient(\$g, \$x); + 17.460 μs (17 allocations: 71.52 KiB) +``` + +See https://github.com/TuringLang/Turing.jl/issues/1934 more further discussion. + +# Notes +To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one +has a function `myfunc` then we need to define + +```julia +make_closure(::typeof(myfunc), ::Type{D}) where {D} = myfunc(D(args...), x) +``` + +This can also be done using `DistributionsAD.@specialize_make_closure`: + +```julia +julia> mylogpdf(d, x) = logpdf(d, x) +mylogpdf (generic function with 1 method) + +julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal) + sum(inner.(data, x)) + end +h (generic function with 1 method) + +julia> @btime ReverseDiff.gradient(\$h, \$x); + 1.220 ms (37011 allocations: 1.42 MiB) + +julia> DistributionsAD.@specialize_make_closure mylogpdf + +julia> @btime ReverseDiff.gradient(\$h, \$x); + 17.038 μs (17 allocations: 71.52 KiB) +``` +""" +make_closure(f, g) = (x, args...) -> f(g(args...), x) +make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x) + + +""" + has_specialized_make_closure(f, g) + +Return `true` if there exists a specialized `make_closure(f, g)` implementation. +""" +has_specialized_make_closure(f, g) = false + +# To go vroooom we need to specialize on the first argument, thus ensuring that +# a different closure is constructed for each method. +""" + @specialize_make_closure(f) + +Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` +and second argument being a type. +""" +macro specialize_make_closure(f) + return quote + $(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x) + $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true + end +end + +""" + @specialize_make_closure(f, g) + +Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` +and second argument being `g`. +""" +macro specialize_make_closure(f, g) + return quote + $(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x) + $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof{$(esc(g))}) = true + end +end + +@specialize_make_closure Distributions.pdf +@specialize_make_closure Distributions.logpdf +@specialize_make_closure Distributions.loglikelihood +@specialize_make_closure Distributions.cdf +@specialize_make_closure Distributions.logcdf diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl new file mode 100644 index 0000000..e2500f9 --- /dev/null +++ b/src/lazyarrays.jl @@ -0,0 +1,48 @@ +using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray + +const LazyVectorOfUnivariate{ + S<:ValueSupport, + T<:UnivariateDistribution{S}, + Tdists<:BroadcastVector{T}, +} = VectorOfUnivariate{S,T,Tdists} + +_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D + +function Distributions._logpdf( + dist::LazyVectorOfUnivariate, + x::AbstractVector{<:Real}, +) + # TODO: Implement chain rule for `LazyArray` constructor to support Zygote. + f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) + # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once + # we've addressed performance issues in ReverseDiff.jl. + return sum(f.(x, dist.v.args...)) +end + +function Distributions.logpdf( + dist::LazyVectorOfUnivariate, + x::AbstractMatrix{<:Real}, + ) + size(x, 1) == length(dist) || + throw(DimensionMismatch("Inconsistent array dimensions.")) + f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) + return vec(sum(f.(x, dist.v.args...), dims = 1)) +end + +const LazyMatrixOfUnivariate{ + S<:ValueSupport, + T<:UnivariateDistribution{S}, + Tdists<:BroadcastArray{T,2}, +} = MatrixOfUnivariate{S,T,Tdists} + +function Distributions._logpdf( + dist::LazyMatrixOfUnivariate, + x::AbstractMatrix{<:Real}, +) + f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) + + return sum(f.(x, dist.v.args)) +end + +lazyarray(f, x...) = BroadcastArray(f, x...) +export lazyarray From cd9c845914b1f73161140f0f463701036cb84e24 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Jan 2023 19:44:03 +0000 Subject: [PATCH 02/17] added rrule for BroadcastArray to ensure Zygote compat --- src/lazyarrays.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index e2500f9..ba0cc79 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -46,3 +46,8 @@ end lazyarray(f, x...) = BroadcastArray(f, x...) export lazyarray + +# Necessary to make `BroadcastArray` work nicely with Zygote. +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...) + return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) +end From 3cfa86ebdee630f81e59e61b0d8091ffee579653 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Jan 2023 19:53:25 +0000 Subject: [PATCH 03/17] use make_closure in logpdf only if specialized --- src/lazyarrays.jl | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index ba0cc79..dc5e354 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,5 +1,10 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray +# Necessary to make `BroadcastArray` work nicely with Zygote. +function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCores.HasReverseMode}, ::Type{BroadcastArray}, f, args...) + return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) +end + const LazyVectorOfUnivariate{ S<:ValueSupport, T<:UnivariateDistribution{S}, @@ -12,21 +17,30 @@ function Distributions._logpdf( dist::LazyVectorOfUnivariate, x::AbstractVector{<:Real}, ) - # TODO: Implement chain rule for `LazyArray` constructor to support Zygote. - f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once # we've addressed performance issues in ReverseDiff.jl. - return sum(f.(x, dist.v.args...)) + constructor = _inner_constructor(typeof(dist.v)) + return if has_specalized_make_closure(logpdf, constructor) + f = make_closure(logpdf, constructor) + sum(f.(x, dist.v.args...)) + else + sum(copy(logpdf.(dist, x))) + end end function Distributions.logpdf( dist::LazyVectorOfUnivariate, x::AbstractMatrix{<:Real}, - ) +) size(x, 1) == length(dist) || throw(DimensionMismatch("Inconsistent array dimensions.")) - f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) - return vec(sum(f.(x, dist.v.args...), dims = 1)) + constructor = _inner_constructor(typeof(dist.v)) + return if has_specialized_make_closure(logpdf, constructor) + f = make_closure(logpdf, constructor) + vec(sum(f.(x, dist.v.args...), dims = 1)) + else + vec(sum(copy(logpdf.(dist, x)); dims = 1)) + end end const LazyMatrixOfUnivariate{ @@ -39,15 +53,15 @@ function Distributions._logpdf( dist::LazyMatrixOfUnivariate, x::AbstractMatrix{<:Real}, ) - f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) - - return sum(f.(x, dist.v.args)) + + constructor = _inner_constructor(typeof(dist.v)) + return if has_specialized_make_closure(logpdf, constructor) + f = make_closure(logpdf, constructor) + sum(f.(x, dist.v.args)) + else + sum(copy(logpdf.(dist.dists, x))) + end end lazyarray(f, x...) = BroadcastArray(f, x...) export lazyarray - -# Necessary to make `BroadcastArray` work nicely with Zygote. -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...) - return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) -end From 896bce3ad58c49c12d0062d815a48edbc67d4668 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:00:21 +0000 Subject: [PATCH 04/17] Update src/lazyarrays.jl --- src/lazyarrays.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index dc5e354..975d0d2 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,4 +1,5 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray +using ChainRulesCore: ChainRulesCore # Necessary to make `BroadcastArray` work nicely with Zygote. function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCores.HasReverseMode}, ::Type{BroadcastArray}, f, args...) From fcf9bf45336235f4c39e412c47069b75852ab841 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:07:33 +0000 Subject: [PATCH 05/17] fixed typo --- src/lazyarrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index dc5e354..b85f56d 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -20,7 +20,7 @@ function Distributions._logpdf( # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once # we've addressed performance issues in ReverseDiff.jl. constructor = _inner_constructor(typeof(dist.v)) - return if has_specalized_make_closure(logpdf, constructor) + return if has_specialized_make_closure(logpdf, constructor) f = make_closure(logpdf, constructor) sum(f.(x, dist.v.args...)) else From 74d4e38d0350f251c8687a2767f7616ec889a5ad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:12:02 +0000 Subject: [PATCH 06/17] Update src/lazyarrays.jl --- src/lazyarrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index a353752..1e876cc 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,5 +1,5 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray -using ChainRulesCore: ChainRulesCore +using .ChainRulesCore: ChainRulesCore # Necessary to make `BroadcastArray` work nicely with Zygote. function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCores.HasReverseMode}, ::Type{BroadcastArray}, f, args...) From ae52b81705d5416ea3436e774c93918fb947d22f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:12:44 +0000 Subject: [PATCH 07/17] okay final fix for chainrules --- src/lazyarrays.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 1e876cc..e2ce573 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,8 +1,7 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray -using .ChainRulesCore: ChainRulesCore # Necessary to make `BroadcastArray` work nicely with Zygote. -function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCores.HasReverseMode}, ::Type{BroadcastArray}, f, args...) +function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) end From fcdd588b42b6846cf6a07c9cdbcf45e4ab355ed1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:36:50 +0000 Subject: [PATCH 08/17] use zygoterules instead --- src/lazyarrays.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index e2ce573..61120d3 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,8 +1,12 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray -# Necessary to make `BroadcastArray` work nicely with Zygote. -function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) - return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) +# # Necessary to make `BroadcastArray` work nicely with Zygote. +# function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) +# return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) +# end + +ZygoteRules.@adjoint function BroadcastArray(f, args...) + return ZygoteRules.pullback(Broadcast.broadcasted, f, args...) end const LazyVectorOfUnivariate{ From 59423de16ba395b853cfc65f0764ac6fa5e01231 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Jan 2023 14:49:31 +0000 Subject: [PATCH 09/17] ZygoteRules didn't help --- src/lazyarrays.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 61120d3..e2ce573 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,12 +1,8 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray -# # Necessary to make `BroadcastArray` work nicely with Zygote. -# function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) -# return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) -# end - -ZygoteRules.@adjoint function BroadcastArray(f, args...) - return ZygoteRules.pullback(Broadcast.broadcasted, f, args...) +# Necessary to make `BroadcastArray` work nicely with Zygote. +function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) + return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) end const LazyVectorOfUnivariate{ From bcfdecf4fba5876c56e78e84b51956a18b2e0012 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 17 Jan 2023 22:35:57 +0000 Subject: [PATCH 10/17] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 185b08a..6d91eef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.43" +version = "0.7" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 90d3bbc53e4ce50b3af9f52af12e5113d441984c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Jan 2023 03:17:39 +0000 Subject: [PATCH 11/17] replaced macros and closures with struct and generated --- src/common.jl | 85 ++++++++--------------------------------------- src/lazyarrays.jl | 21 ++---------- 2 files changed, 16 insertions(+), 90 deletions(-) diff --git a/src/common.jl b/src/common.jl index 15dda4f..26b73f2 100644 --- a/src/common.jl +++ b/src/common.jl @@ -49,10 +49,11 @@ parameterless_type(x::Type) = __parameterless_type(x) @non_differentiable adapt_randn(::Any...) + """ - make_closure(f, g) + Closure{F,G} -Return a closure of the form `(x, args...) -> f(g(args...), x)`. +A callable of the form `(x, args...) -> F(G(args...), x)`. # Examples @@ -74,9 +75,7 @@ julia> @btime ReverseDiff.gradient(\$f, \$x); 848.759 μs (14605 allocations: 521.84 KiB) julia> # Much faster with ReverseDiff.jl. - g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal) - sum(g_inner.(data, x)) - end + g(x) = sum(DistributionsAD.Closure(logpd, Normal).(data, x)) g (generic function with 1 method) julia> @btime ReverseDiff.gradient(\$g, \$x); @@ -84,76 +83,18 @@ julia> @btime ReverseDiff.gradient(\$g, \$x); ``` See https://github.com/TuringLang/Turing.jl/issues/1934 more further discussion. - -# Notes -To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one -has a function `myfunc` then we need to define - -```julia -make_closure(::typeof(myfunc), ::Type{D}) where {D} = myfunc(D(args...), x) -``` - -This can also be done using `DistributionsAD.@specialize_make_closure`: - -```julia -julia> mylogpdf(d, x) = logpdf(d, x) -mylogpdf (generic function with 1 method) - -julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal) - sum(inner.(data, x)) - end -h (generic function with 1 method) - -julia> @btime ReverseDiff.gradient(\$h, \$x); - 1.220 ms (37011 allocations: 1.42 MiB) - -julia> DistributionsAD.@specialize_make_closure mylogpdf - -julia> @btime ReverseDiff.gradient(\$h, \$x); - 17.038 μs (17 allocations: 71.52 KiB) -``` -""" -make_closure(f, g) = (x, args...) -> f(g(args...), x) -make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x) - - -""" - has_specialized_make_closure(f, g) - -Return `true` if there exists a specialized `make_closure(f, g)` implementation. """ -has_specialized_make_closure(f, g) = false +struct Closure{F,G} end -# To go vroooom we need to specialize on the first argument, thus ensuring that -# a different closure is constructed for each method. -""" - @specialize_make_closure(f) +Closure(::F, ::G) where {F,G} = Closure{F,G}() +Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}() +Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}() +Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}() -Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` -and second argument being a type. -""" -macro specialize_make_closure(f) - return quote - $(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x) - $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true - end +@generated function (closure::Closure{F,G})(x, args...) where {F,G} + f = Base.issingletontype(F) ? F.instance : F + g = Base.issingletontype(G) ? G.instance : G + return :($f($g(args...), x)) end -""" - @specialize_make_closure(f, g) - -Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` -and second argument being `g`. -""" -macro specialize_make_closure(f, g) - return quote - $(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x) - $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof{$(esc(g))}) = true - end -end -@specialize_make_closure Distributions.pdf -@specialize_make_closure Distributions.logpdf -@specialize_make_closure Distributions.loglikelihood -@specialize_make_closure Distributions.cdf -@specialize_make_closure Distributions.logcdf diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index e2ce573..325ba0f 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -20,12 +20,7 @@ function Distributions._logpdf( # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once # we've addressed performance issues in ReverseDiff.jl. constructor = _inner_constructor(typeof(dist.v)) - return if has_specialized_make_closure(logpdf, constructor) - f = make_closure(logpdf, constructor) - sum(f.(x, dist.v.args...)) - else - sum(copy(logpdf.(dist, x))) - end + return sum(Closure(logpdf, constructor).(x, dist.v.args...)) end function Distributions.logpdf( @@ -35,12 +30,7 @@ function Distributions.logpdf( size(x, 1) == length(dist) || throw(DimensionMismatch("Inconsistent array dimensions.")) constructor = _inner_constructor(typeof(dist.v)) - return if has_specialized_make_closure(logpdf, constructor) - f = make_closure(logpdf, constructor) - vec(sum(f.(x, dist.v.args...), dims = 1)) - else - vec(sum(copy(logpdf.(dist, x)); dims = 1)) - end + return vec(sum(Closure(logpdf, constructor).(x, dist.v.args...), dims = 1)) end const LazyMatrixOfUnivariate{ @@ -55,12 +45,7 @@ function Distributions._logpdf( ) constructor = _inner_constructor(typeof(dist.v)) - return if has_specialized_make_closure(logpdf, constructor) - f = make_closure(logpdf, constructor) - sum(f.(x, dist.v.args)) - else - sum(copy(logpdf.(dist.dists, x))) - end + return sum(Closure(logpdf, constructor).(x, dist.v.args)) end lazyarray(f, x...) = BroadcastArray(f, x...) From 9a1b20165a4b891f5269fe9c067191041a454f66 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Jan 2023 03:17:54 +0000 Subject: [PATCH 12/17] make Zygote realize it should use ForwardDiff when broadcasting Closure --- src/DistributionsAD.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 4770ea0..808b574 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -70,6 +70,12 @@ include("zygote.jl") end end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + using .Zygote: Zygote + # HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting. + Zygote._dual_purefun(::Type{C}) where {C<:Closure} = Base.issingletontype(C) + end + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin using DiffRules using SpecialFunctions From a792e8bca58402907011d314b149ee93720a7a93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Jan 2023 13:42:51 +0000 Subject: [PATCH 13/17] added custom adjoint for logpdf with broadcastarray and stuff --- Project.toml | 2 +- src/lazyarrays.jl | 47 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 6d91eef..37e58cd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.7" +version = "0.6.44" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 325ba0f..9f2e904 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -1,10 +1,5 @@ using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray -# Necessary to make `BroadcastArray` work nicely with Zygote. -function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...) - return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) -end - const LazyVectorOfUnivariate{ S<:ValueSupport, T<:UnivariateDistribution{S}, @@ -50,3 +45,45 @@ end lazyarray(f, x...) = BroadcastArray(f, x...) export lazyarray + +# HACK: All of the below probably shouldn't be here. +function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...) + function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent) + return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...) + end + return BroadcastArray(f, args...), BroadcastArray_pullback +end + +ChainRulesCore.ProjectTo(ba::BroadcastArray) = ProjectTo{typeof(ba)}((f=ba.f,)) +function (p::ChainRulesCore.ProjectTo{BA})(args...) where {BA<:BroadcastArray} + return ChainRulesCore.Tangent{BA}(f=p.f, args=args) +end + +function ChainRulesCore.rrule( + config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Distributions.logpdf), + dist::LazyVectorOfUnivariate, + x::AbstractVector{<:Real} +) + cl = DistributionsAD.Closure(logpdf, DistributionsAD._inner_constructor(typeof(dist.v))) + y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...) + z, dz = ChainRulesCore.rrule_via_ad(config, sum, y) + + project_broadcastarray = ChainRulesCore.ProjectTo(dist.v) + function logpdf_adjoint(Δ...) + # 1st argument is `sum` -> nothing. + (_, sum_Δ...) = dz(Δ...) + # 1st argument is `broadcast` -> nothing. + # 2nd argument is `cl` -> `nothing`. + # 3rd argument is `x` -> something. + # Rest is `dist` arguments -> something + (_, _, x_Δ, args_Δ...) = dy(sum_Δ...) + # Construct the structural tangents. + ba_tangent = project_broadcastarray(args_Δ...) + dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent) + + return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ) + end + + return z, logpdf_adjoint +end From c9324a3f0f0533ae0bf0ce431a0ef08b9369feee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Jan 2023 19:11:24 +0000 Subject: [PATCH 14/17] only use faster path for Zygote if we're working witha Closure that supports it --- src/DistributionsAD.jl | 3 ++- src/common.jl | 40 ++++++++++++++++++++++++++++++++++++++++ src/lazyarrays.jl | 12 +++++++++++- 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 808b574..fad23d0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -73,7 +73,8 @@ include("zygote.jl") @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin using .Zygote: Zygote # HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting. - Zygote._dual_purefun(::Type{C}) where {C<:Closure} = Base.issingletontype(C) + # See `is_diff_safe` for more information. + Zygote._dual_purefun(::Type{C}) where {C<:Closure} = is_diff_safe(C) end @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin diff --git a/src/common.jl b/src/common.jl index 26b73f2..1a6e950 100644 --- a/src/common.jl +++ b/src/common.jl @@ -91,6 +91,46 @@ Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}() Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}() Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}() +""" + is_diff_safe(f) + +Return `true` if it's safe to ignore gradients wrt. `f` when computing `f`. + +Useful for checking it's okay to take faster paths in pullbacks for certain AD backends. + +# Examples + +```julia +julia> using Distributions + +julia> using DistributionsAD: is_diff_safe, Closure + +julia> is_diff_safe(typeof(logpdf)) +true + +julia> is_diff_safe(typeof(x -> 2x)) +true + +julia> # But it fails if we make a closure over a variable, which we might want to compute + # the gradient with respect to. + makef(x) = y -> x + y +makef (generic function with 1 method) + +julia> is_diff_safe(typeof(makef([1.0]))) +false + +julia> # Also works on `Closure`s from `DistributionsAD`. + is_diff_safe(typeof(Closure(logpdf, Normal))) +true + +julia> is_diff_safe(typeof(Closure(logpdf, makef([1.0])))) +false +""" +@inline is_diff_safe(_) = false +@inline is_diff_safe(::Type) = true +@inline is_diff_safe(::Type{F}) where {F<:Function} = Base.issingletontype(F) +@inline is_diff_safe(::Type{Closure{F,G}}) where {F,G} = is_diff_safe(F) && is_diff_safe(G) + @generated function (closure::Closure{F,G})(x, args...) where {F,G} f = Base.issingletontype(F) ? F.instance : F g = Base.issingletontype(G) ? G.instance : G diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 9f2e904..f76e207 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -65,7 +65,17 @@ function ChainRulesCore.rrule( dist::LazyVectorOfUnivariate, x::AbstractVector{<:Real} ) - cl = DistributionsAD.Closure(logpdf, DistributionsAD._inner_constructor(typeof(dist.v))) + # Extract the constructor used in the `BroadcastArray`. + constructor = DistributionsAD._inner_constructor(typeof(dist.v)) + + # If it's not safe to ignore the `constructor` in the pullback, then we fall back + # to the default implementation. + is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x) + + # Otherwise, we use `Closure`. + cl = DistributionsAD.Closure(logpdf, constructor) + + # Construct pullbacks manually to avoid the constructor of `BroadcastArray`. y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...) z, dz = ChainRulesCore.rrule_via_ad(config, sum, y) From b3b27866ac611a754030ca5099a48053046e8679 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Jan 2023 19:14:07 +0000 Subject: [PATCH 15/17] make a docstring doctest instead --- src/common.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common.jl b/src/common.jl index 1a6e950..012acf5 100644 --- a/src/common.jl +++ b/src/common.jl @@ -100,7 +100,7 @@ Useful for checking it's okay to take faster paths in pullbacks for certain AD b # Examples -```julia +```jldoctest julia> using Distributions julia> using DistributionsAD: is_diff_safe, Closure From a43d6e5380fd7dd411bbc29c04fd706c0990fdd0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Jan 2023 00:05:37 +0000 Subject: [PATCH 16/17] Update src/common.jl Co-authored-by: David Widmann --- src/common.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common.jl b/src/common.jl index 012acf5..74d82f3 100644 --- a/src/common.jl +++ b/src/common.jl @@ -75,7 +75,7 @@ julia> @btime ReverseDiff.gradient(\$f, \$x); 848.759 μs (14605 allocations: 521.84 KiB) julia> # Much faster with ReverseDiff.jl. - g(x) = sum(DistributionsAD.Closure(logpd, Normal).(data, x)) + g(x) = sum(DistributionsAD.Closure(logpdf, Normal).(data, x)) g (generic function with 1 method) julia> @btime ReverseDiff.gradient(\$g, \$x); From 8604902e417b75edaa7052924b94d830ee0ef413 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Jan 2023 00:13:43 +0000 Subject: [PATCH 17/17] Update src/lazyarrays.jl Co-authored-by: David Widmann --- src/lazyarrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index f76e207..f6db1eb 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -61,7 +61,7 @@ end function ChainRulesCore.rrule( config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, - ::typeof(Distributions.logpdf), + ::typeof(logpdf), dist::LazyVectorOfUnivariate, x::AbstractVector{<:Real} )