Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Base.convert(::Type{API.CDerivativeMode}, ::ForwardMode) = API.DEM_ForwardMode
function guess_activity end

mutable struct EnzymeContext
world::UInt64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should already be accessible everywhere gutils is available per enzyme_extract_world(fn) or something

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is meant as a proof of concept to demonstrate that yes we can get the world we are compiling from everywhere.

Using a string attribute is fine, but is also a blocker for caching the IR for later use, so I wanted to see if we can use this to avoid adding ephemeral data to potentially cacheable inputs.

end

include("logic.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,15 @@ EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall
EnzymeGetShadowType(width, T) =
ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64, LLVMTypeRef), width, T)

function EnzymeGradientUtilsGetExternalContext(gutils)
ccall(
(:EnzymeGradientUtilsGetExternalContext, libEnzyme),
Ptr{Cvoid},
(EnzymeGradientUtilsRef,),
gutils,
)
end

EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall(
(:EnzymeGradientUtilsReplaceAWithB, libEnzyme),
Cvoid,
Expand Down
45 changes: 29 additions & 16 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ include("compiler/utils.jl")

include("compiler/orcv2.jl")

include("gradientutils.jl")

import .Enzyme: GradientUtils, call_samefunc_with_inverted_bundles!,
get_width, get_mode, get_runtime_activity,
get_strong_zero, get_shadow_type, get_uncacheable,
erase_with_placeholder, is_constant_value, is_constant_inst,
new_from_original, lookup_value, invert_pointer, debug_from_orig!,
add_reverse_block!, set_reverse_block!, enzyme_context, enzyme_gutils_context

# Julia function to LLVM stem and arity
const cmplx_known_ops =
Expand Down Expand Up @@ -482,12 +486,13 @@ include("llvm/transforms.jl")
include("llvm/passes.jl")
include("typeutils/make_zero.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
nested_codegen!(mode, mod, funcspec, world)
function nested_codegen!(ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type))
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, ctx.world)
nested_codegen!(ctx, mode, mod, funcspec)
end

function prepare_llvm(interp, mod::LLVM.Module, job, meta)
# TODO: remove enzymejl_world
for f in functions(mod)
attributes = function_attributes(f)
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
Expand Down Expand Up @@ -625,7 +630,7 @@ end
name = meth.name
jlmod = meth.module

julia_activity_rule(llvmfn)
julia_activity_rule(llvmfn, world)
if has_custom_rule
handleCustom(
state,
Expand Down Expand Up @@ -1079,6 +1084,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
end

world = enzyme_extract_world(f)
@assert world == interp.world

if expectLen != length(parameters(f))
continue
Expand Down Expand Up @@ -1229,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false)
const DumpPostNestedOpt = Ref(false)

function nested_codegen!(
ctx::EnzymeContext,
mode::API.CDerivativeMode,
mod::LLVM.Module,
funcspec::Core.MethodInstance,
world::UInt,
)
world = ctx.world
# TODO: Put a cache here index on `mod` and f->tt


Expand All @@ -1249,6 +1256,7 @@ function nested_codegen!(
GPUCompiler.prepare_job!(job)
otherMod, meta = GPUCompiler.emit_llvm(job)

# TODO: interp should be cached since it contains internal caches
interp = GPUCompiler.get_interpreter(job)
prepare_llvm(interp, otherMod, job, meta)

Expand Down Expand Up @@ -1664,6 +1672,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
mode == API.DEM_ReverseModeCombined
fn = LLVM.parent(LLVM.parent(V))
world = enzyme_extract_world(fn)
@assert world == enzyme_context(gutils).world
if !guaranteed_nonactive(Ty, world)
B = LLVM.IRBuilder()
position!(B, V)
Expand Down Expand Up @@ -2392,6 +2401,7 @@ const DumpPostEnzyme = Ref(false)
const DumpPostWrap = Ref(false)

function enzyme!(
enzyme_context::EnzymeContext,
job::CompilerJob,
interp,
mod::LLVM.Module,
Expand Down Expand Up @@ -2507,7 +2517,6 @@ function enzyme!(
convert(API.CDIFFE_TYPE, rt)
end

enzyme_context = EnzymeContext()
GC.@preserve enzyme_context begin
LLVM.@dispose logic = Logic(enzyme_context) begin

Expand Down Expand Up @@ -2577,6 +2586,7 @@ function enzyme!(

if wrap
augmented_primalf = create_abi_wrapper(
enzyme_context,
augmented_primalf,
TT,
rt,
Expand All @@ -2586,7 +2596,6 @@ function enzyme!(
width,
returnPrimal,
shadow_init,
world,
interp,
runtimeActivity,
)
Expand Down Expand Up @@ -2619,6 +2628,7 @@ function enzyme!(
) #=atomicAdd=#
if wrap
adjointf = create_abi_wrapper(
enzyme_context,
adjointf,
TT,
rt,
Expand All @@ -2628,7 +2638,6 @@ function enzyme!(
width,
false,
shadow_init,
world,
interp,
runtimeActivity
) #=returnPrimal=#
Expand Down Expand Up @@ -2660,6 +2669,7 @@ function enzyme!(
augmented_primalf = nothing
if wrap
adjointf = create_abi_wrapper(
enzyme_context,
adjointf,
TT,
rt,
Expand All @@ -2669,7 +2679,6 @@ function enzyme!(
width,
returnPrimal,
shadow_init,
world,
interp,
runtimeActivity
)
Expand Down Expand Up @@ -2705,6 +2714,7 @@ function enzyme!(
if wrap
pf = adjointf
adjointf = create_abi_wrapper(
enzyme_context,
adjointf,
TT,
rt,
Expand All @@ -2714,7 +2724,6 @@ function enzyme!(
width,
returnPrimal,
shadow_init,
world,
interp,
runtimeActivity
)
Expand Down Expand Up @@ -2786,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp)
end

function create_abi_wrapper(
ctx::EnzymeContext,
enzymefn::LLVM.Function,
@nospecialize(TT::Type),
@nospecialize(rettype::Type),
Expand All @@ -2795,10 +2805,10 @@ function create_abi_wrapper(
width::Int,
returnPrimal::Bool,
shadow_init::Bool,
world::UInt,
interp,
runtime_activity::Bool
)
world = ctx.world
is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined
is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal
needs_tape = Mode == API.DEM_ReverseModeGradient
Expand Down Expand Up @@ -3081,6 +3091,7 @@ function create_abi_wrapper(
realparms = LLVM.Value[]
i = 1

# TODO(vchuravy): remove
for attr in collect(function_attributes(enzymefn))
if kind(attr) == "enzymejl_world"
push!(function_attributes(llvm_f), attr)
Expand Down Expand Up @@ -3225,7 +3236,7 @@ function create_abi_wrapper(
elseif T <: BatchDuplicatedFunc
Func = get_func(T)
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
llvmf = nested_codegen!(Mode, mod, funcspec, world)
llvmf = nested_codegen!(ctx, Mode, mod, funcspec)
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
Func_RT = return_type(interp, funcspec)
@assert Func_RT == NTuple{width,T′}
Expand Down Expand Up @@ -5096,15 +5107,17 @@ end
end
end

ctx = EnzymeContext(job.world)
if params.run_enzyme
# Generate the adjoint
memcpy_alloca_to_loadstore(mod)
force_recompute!(mod)
API.EnzymeDetectReadonlyOrThrow(mod)

adjointf, augmented_primalf, TapeType = enzyme!(
ctx,
job,
interp,
interp,
mod,
primalf,
TT,
Expand Down Expand Up @@ -5203,7 +5216,7 @@ end
fname = String(name) * pf
if haskey(functions(mod), fname)
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job.world)
llvmf = nested_codegen!(mode, mod, funcspec, job.world)
llvmf = nested_codegen!(ctx, mode, mod, funcspec)
push!(function_attributes(llvmf), StringAttribute("implements", fname))
end
end
Expand Down
9 changes: 8 additions & 1 deletion src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ function julia_error(
) #=error=#
world = enzyme_extract_world(f)
end
# TODO: get world from TypeAnalyzer
# @assert world == enzyme_gutils_context(gutils).world
throw(IllegalTypeAnalysisException(msg, mi, world, sval, ir, bt))
Comment on lines +478 to 480
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses would you be okay with a change where CustomErrorHandler always receives gutils? We can go from gutils to TypeAnalyzer (if I saw that right).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the particular error thrown there may not be a gutils created.

But we can for sure audit all the current error types and see what we can do

elseif errtype == API.ET_NoType
@assert B != C_NULL
Expand Down Expand Up @@ -550,6 +552,7 @@ function julia_error(
illegal = false
created = LLVM.Instruction[]
world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B))))
@assert world == enzyme_context(gutils).world
width = get_width(gutils)
function make_batched(@nospecialize(cur::LLVM.Value), B::LLVM.IRBuilder)::LLVM.Value
if width == 1
Expand Down Expand Up @@ -944,7 +947,7 @@ end
end
end

mi = nothing
mi = nothing
world = nothing

if isa(val, LLVM.Instruction)
Expand All @@ -962,6 +965,10 @@ end
) #=error=#
world = enzyme_extract_world(f)
end
# TODO(vchuravy)
# what is data?
# Can we get world here?
# @assert world == enzyme_context(gutils).world
mode = Enzyme.API.DEM_ReverseModeCombined

if mi !== nothing
Expand Down
15 changes: 14 additions & 1 deletion src/gradientutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ erase_with_placeholder(
orig::LLVM.Instruction,
erase::Bool = true,
) = API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase)

is_constant_value(gutils::GradientUtils, val::LLVM.Value) =
API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0

Expand Down Expand Up @@ -96,4 +97,16 @@ end

function set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock)
return LLVM.BasicBlock(API.EnzymeGradientUtilsSetReverseBlock(gutils, block))
end
end

function enzyme_context(gutils::GradientUtils)
ptr = API.EnzymeGradientUtilsGetExternalContext(gutils)
@assert ptr != C_NULL
return unsafe_pointer_to_objref(ptr)::EnzymeContext
end

function enzyme_gutils_context(gutils::API.EnzymeGradientUtilsRef)
ptr = API.EnzymeGradientUtilsGetExternalContext(gutils)
@assert ptr != C_NULL
return unsafe_pointer_to_objref(ptr)::EnzymeContext
end
2 changes: 1 addition & 1 deletion src/logic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function enzyme_context(logic::Logic)
return logic.ctx::EnzymeContext
end

function enzyme_context(logic::API.EnzymeLogicRef)
function enzyme_logic_context(logic::API.EnzymeLogicRef)
ptr = API.LogicGetExternalContext(logic)
@assert ptr != C_NULL
return unsafe_pointer_to_objref(ptr)::EnzymeContext
Expand Down
3 changes: 1 addition & 2 deletions src/rules/activityrules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

function julia_activity_rule(f::LLVM.Function)
function julia_activity_rule(f::LLVM.Function, world)
if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1")
return
end
Expand Down Expand Up @@ -30,7 +30,6 @@ function julia_activity_rule(f::LLVM.Function)
if mi.specTypes.parameters[end] === Vararg{Any}
return
end
world = enzyme_extract_world(f)

# TODO fix the attributor inlining such that this can assert always true
if expectLen != length(parameters(f))
Expand Down
Loading
Loading