From 839c6e04d631a015821b0068ce8654156841799d Mon Sep 17 00:00:00 2001 From: Lucas C Wilcox Date: Wed, 5 Nov 2025 10:49:17 -0800 Subject: [PATCH] Add MPI extension with Allreduce! forward rule Co-authored-by: Valentin Churavy --- Project.toml | 4 ++++ ext/EnzymeMPIExt.jl | 37 +++++++++++++++++++++++++++++ test/integration/MPI/collectives.jl | 26 ++++++++++++++++++++ test/integration/MPI/runtests.jl | 4 ++++ 4 files changed, 71 insertions(+) create mode 100644 ext/EnzymeMPIExt.jl create mode 100644 test/integration/MPI/collectives.jl diff --git a/Project.toml b/Project.toml index cd08bc56af..be4e4188a9 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -35,6 +36,7 @@ EnzymeChainRulesCoreExt = "ChainRulesCore" EnzymeDynamicPPLExt = ["ADTypes", "DynamicPPL"] EnzymeGPUArraysCoreExt = "GPUArraysCore" EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeMPIExt = "MPI" EnzymeSpecialFunctionsExt = "SpecialFunctions" EnzymeStaticArraysExt = "StaticArrays" @@ -50,6 +52,7 @@ GPUArraysCore = "0.1.6, 0.2" GPUCompiler = "1.6.2" LLVM = "9.1" LogExpFunctions = "0.3" +MPI = "0.20" ObjectFile = "0.4, 0.5" PrecompileTools = "1" Preferences = "1.4" @@ -65,6 +68,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/ext/EnzymeMPIExt.jl b/ext/EnzymeMPIExt.jl new file mode 100644 index 0000000000..ae737d9b41 --- /dev/null +++ b/ext/EnzymeMPIExt.jl @@ -0,0 +1,37 @@ +module EnzymeMPIExt + +using MPI +using Enzyme + +import Enzyme.EnzymeCore: EnzymeRules + +function EnzymeRules.forward(config, ::Const{typeof(MPI.Allreduce!)}, rt, v, op::Const, comm::Const) + op = op.val + comm = comm.val + + if !(op == MPI.SUM || op == +) + error("Forward mode MPI.Allreduce! is only implemented for MPI.SUM.") + end + + if EnzymeRules.needs_primal(config) + MPI.Allreduce!(v.val, op, comm) + end + + if EnzymeRules.width(config) == 1 + MPI.Allreduce!(v.dval, op, comm) + else + # would be nice to use MPI non-blocking collectives + foreach(v.dval) do dval + MPI.Allreduce!(dval, op, comm) + end + end + + if EnzymeRules.needs_primal(config) + return v + else + return v.dval + end +end + + +end diff --git a/test/integration/MPI/collectives.jl b/test/integration/MPI/collectives.jl new file mode 100644 index 0000000000..afbbc054de --- /dev/null +++ b/test/integration/MPI/collectives.jl @@ -0,0 +1,26 @@ +using MPI +using Enzyme +using Test + +MPI.Init() + +@show Base.get_extension(Enzyme, :EnzymeMPIExt) + +buff = Ref(3.0) +comm = MPI.COMM_WORLD + +MPI.Allreduce!(buff, MPI.SUM, comm) + +@test buff[] == MPI.Comm_size(comm) * 3.0 + +buff[] = 3.0 +dbuff = Ref(0.0) + +if MPI.Comm_rank(comm) == 0 + dbuff[] = 1.0 +end + +autodiff(ForwardWithPrimal, MPI.Allreduce!, Duplicated(buff, dbuff), Const(MPI.SUM), Const(comm)) + +@test buff[] == MPI.Comm_size(comm) * 3.0 +@test dbuff[] == 1.0 diff --git a/test/integration/MPI/runtests.jl b/test/integration/MPI/runtests.jl index 706ad24fd5..4bce80bf6d 100644 --- a/test/integration/MPI/runtests.jl +++ b/test/integration/MPI/runtests.jl @@ -1,3 +1,7 @@ using MPI using Enzyme using Test + +@testset "collectives" for np in (1, 2, 4) + run(`$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $(joinpath(@__DIR__, "collectives.jl"))`) +end