Skip to content

feat: support forward-mode Mooncake with AutoMooncakeForward #813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 13, 2025
6 changes: 3 additions & 3 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.4"
version = "0.7.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.13.0"
ADTypes = "1.17.0"
Aqua = "0.8.12"
ChainRulesCore = "1.23.0"
ComponentArrays = "0.15.27"
Expand All @@ -77,7 +77,7 @@ JET = "0.9"
JLArrays = "0.2.0"
JuliaFormatter = "1,2"
LinearAlgebra = "1"
Mooncake = "0.4.122"
Mooncake = "0.4.147"
Pkg = "1"
PolyesterForwardDiff = "0.1.2"
Random = "1"
Expand Down
6 changes: 4 additions & 2 deletions DifferentiationInterface/docs/src/explanation/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
- [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental)
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
- [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
Expand Down Expand Up @@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand All @@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends:
| `AutoForwardDiff` |||
| `AutoGTPSA` |||
| `AutoMooncake` |||
| `AutoMooncakeForward` |||
| `AutoPolyesterForwardDiff` |||
| `AutoReverseDiff` |||
| `AutoSymbolics` |||
Expand Down Expand Up @@ -95,7 +97,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).

## Implementations

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
module DifferentiationInterfaceMooncakeExt

using ADTypes: ADTypes, AutoMooncake
using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
Dual,
prepare_derivative_cache,
prepare_gradient_cache,
prepare_pullback_cache,
primal,
tangent,
tangent_type,
value_and_derivative!!,
value_and_gradient!!,
value_and_pullback!!,
zero_dual,
zero_tangent,
rdata_type,
fdata,
Expand All @@ -25,17 +31,17 @@ using Mooncake:
_copy_output,
_copy_to_output!!

DI.check_available(::AutoMooncake) = true
const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}}

get_config(::AutoMooncake{Nothing}) = Config()
get_config(backend::AutoMooncake{<:Config}) = backend.config
DI.check_available(::AnyAutoMooncake{C}) where {C} = true

# tangents need to be copied before returning, otherwise they are still aliased in the cache
mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
mycopy(x) = deepcopy(x)
get_config(::AnyAutoMooncake{Nothing}) = Config()
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config

include("onearg.jl")
include("twoarg.jl")
include("forward_onearg.jl")
include("forward_twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
## Pushforward

struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
end

function DI.prepare_pushforward_nokwarg(
strict::Val,
f::F,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
dx_righttype = zero_tangent(x)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype)
return prep
end

function DI.value_and_pushforward(
f::F,
prep::MooncakeOneArgPushforwardPrep,
backend::AutoMooncakeForward,
x::X,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C,X}
DI.check_prep(f, prep, backend, x, tx, contexts...)
ys_and_ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = value_and_derivative!!(
prep.cache,
zero_dual(f),
Dual(x, dx_righttype),
map(zero_dual DI.unwrap, contexts)...,
)
y = primal(y_dual)
dy = _copy_output(tangent(y_dual))
return y, dy
end
y = first(ys_and_ty[1])
ty = last.(ys_and_ty)
return y, ty
end

function DI.pushforward(
f::F,
prep::MooncakeOneArgPushforwardPrep,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
end

function DI.value_and_pushforward!(
f::F,
ty::NTuple,
prep::MooncakeOneArgPushforwardPrep,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
return y, ty
end

function DI.pushforward!(
f::F,
ty::NTuple,
prep::MooncakeOneArgPushforwardPrep,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)
return ty
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
## Pushforward

struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
dy_righttype::DY
end

function DI.prepare_pushforward_nokwarg(
strict::Val,
f!::F,
y,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
f!,
y,
x,
map(DI.unwrap, contexts)...;
config.debug_mode,
config.silence_debug_messages,
)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype)
return prep
end

function DI.value_and_pushforward(
f!::F,
y,
prep::MooncakeTwoArgPushforwardPrep,
backend::AutoMooncakeForward,
x::X,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C,X}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = zero_dual(y)
value_and_derivative!!(
prep.cache,
zero_dual(f!),
y_dual,
Dual(x, dx_righttype),
map(zero_dual DI.unwrap, contexts)...,
)
dy = _copy_output(tangent(y_dual))
return dy
end
return y, ty
end

function DI.pushforward(
f!::F,
y,
prep::MooncakeTwoArgPushforwardPrep,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2]
end

function DI.value_and_pushforward!(
f!::F,
y::Y,
ty::NTuple,
prep::MooncakeTwoArgPushforwardPrep,
backend::AutoMooncakeForward,
x::X,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C,X,Y}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
foreach(tx, ty) do dx, dy
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
value_and_derivative!!(
prep.cache,
zero_dual(f!),
Dual(y, dy_righttype),
Dual(x, dx_righttype),
map(zero_dual DI.unwrap, contexts)...,
)
dy === dy_righttype || copyto!(dy, dy_righttype)
end
return y, ty
end

function DI.pushforward!(
f!::F,
y,
ty::NTuple,
prep::MooncakeTwoArgPushforwardPrep,
backend::AutoMooncakeForward,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C};
) where {F,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
return ty
end
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using ADTypes:
AutoForwardDiff,
AutoGTPSA,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
Expand Down Expand Up @@ -115,6 +116,7 @@ export AutoFiniteDifferences
export AutoForwardDiff
export AutoGTPSA
export AutoMooncake
export AutoMooncakeForward
export AutoPolyesterForwardDiff
export AutoReverseDiff
export AutoSymbolics
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be
!!! warning
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
!!! warning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ end;
@testset for scen in filter(differentiatewith_scenarios()) do scen
DIT.operator(scen) == :pullback
end
Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
Mooncake.TestUtils.test_rule(
StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode
)
end
end;

Expand Down
6 changes: 5 additions & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface)

LOGGING = get(ENV, "CI", "false") == "false"

backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]
backends = [
AutoMooncake(; config=nothing),
AutoMooncake(; config=Mooncake.Config()),
AutoMooncakeForward(; config=nothing),
]

for backend in backends
@test check_available(backend)
Expand Down