From 36656ad57b72ecc573e8678f3ca440818ce6e7c1 Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Fri, 24 Oct 2025 23:59:09 +0200 Subject: [PATCH 1/8] fix: `LVM` -> `LLVM`. Fixed the typo. --- src/compiler.jl | 1485 +++++++++++++++++++++++------------------------ 1 file changed, 739 insertions(+), 746 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7bcee309a1..7d918cae0c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -108,9 +108,9 @@ struct PrimalCompilerParams <: AbstractEnzymeCompilerParams end function EnzymeCompilerParams(TT, mode, width, rt, run_enzyme, abiwrap, - modifiedBetween, returnPrimal, shadowInit, - expectedTapeType, ABI, - err_if_func_written, runtimeActivity, strongZero) + modifiedBetween, returnPrimal, shadowInit, + expectedTapeType, ABI, + err_if_func_written, runtimeActivity, strongZero) params = PrimalCompilerParams(mode) EnzymeCompilerParams( params, @@ -132,7 +132,7 @@ function EnzymeCompilerParams(TT, mode, width, rt, run_enzyme, abiwrap, end DefaultCompilerTarget(; kwargs...) = - GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...) + GPUCompiler.NativeCompilerTarget(; jlruntime=true, kwargs...) # TODO: Audit uses function EnzymeTarget() @@ -157,15 +157,15 @@ if VERSION >= v"1.11.0-DEV.1552" always_inline::Any method_table::Core.MethodTable param_type::Type - last_fwd_rule_world::Union{Nothing, Tuple} - last_rev_rule_world::Union{Nothing, Tuple} - last_ina_rule_world::Union{Nothing, Tuple} + last_fwd_rule_world::Union{Nothing,Tuple} + last_rev_rule_world::Union{Nothing,Tuple} + last_ina_rule_world::Union{Nothing,Tuple} end @inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool, inactive_rule::Bool) = EnzymeCacheToken(target_type, always_inline, method_table, param_type, - is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, - is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig,<:Annotation,Type{<:Annotation},Vararg{Annotation}}, world)...,) : nothing, + is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig,<:Annotation,Type{<:Annotation},Vararg{Annotation}}, world)...,) : nothing, inactive_rule ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,) : nothing ) @@ -279,7 +279,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data if (T isa Type) T = T::Type legal = T ∈ Tys - + if legal if name == :ldexp if !(sparam_vals[2] <: Integer) @@ -318,7 +318,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data if (T isa Type) T = T::Type legal = T ∈ Tys - + if legal if !all(==(T), sparam_vals) legal = false @@ -513,9 +513,6 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta) attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT)))), ) - if EnzymeRules.has_easy_rule_from_sig(Interpreter.simplify_kw(mi.specTypes); job.world) - push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow")) - end if returnRoots attr = StringAttribute("enzymejl_returnRoots", "") push!(parameter_attributes(llvmfn, 2), attr) @@ -535,11 +532,11 @@ include("typeutils/inference.jl") import .Interpreter: isKWCallSignature -const mod_to_edges = Dict{LLVM.Module, Vector{Any}}() +const mod_to_edges = Dict{LLVM.Module,Vector{Any}}() mutable struct HandlerState - primalf::Union{Nothing, LLVM.Function} + primalf::Union{Nothing,LLVM.Function} must_wrap::Bool - actualRetType::Union{Nothing, Type} + actualRetType::Union{Nothing,Type} lowerConvention::Bool loweredArgs::Set{Int} boxedArgs::Set{Int} @@ -547,7 +544,7 @@ mutable struct HandlerState end -function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM.Function, name::String, attrs::Vector{LLVM.Attribute} = LLVM.Attribute[], setlink::Bool = true, noinl::Bool = true) +function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM.Function, name::String, attrs::Vector{LLVM.Attribute}=LLVM.Attribute[], setlink::Bool=true, noinl::Bool=true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink @@ -564,7 +561,7 @@ function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM. nothing end -function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, mode::API.CDerivativeMode, world::UInt, method_table, custom::Dict{String, LLVM.API.LLVMLinkage}, mod::LLVM.Module, mi::Core.MethodInstance, k_name::String, @nospecialize(rettype::Type))::Nothing +function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, mode::API.CDerivativeMode, world::UInt, method_table, custom::Dict{String,LLVM.API.LLVMLinkage}, mod::LLVM.Module, mi::Core.MethodInstance, k_name::String, @nospecialize(rettype::Type))::Nothing has_custom_rule = false specTypes = Interpreter.simplify_kw(mi.specTypes) @@ -613,13 +610,13 @@ function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, m func = mi.specTypes.parameters[1] -@static if VERSION < v"1.11-" -else - if func == typeof(Core.memoryref) - attributes = function_attributes(llvmfn) - push!(attributes, EnumAttribute("alwaysinline", 0)) + @static if VERSION < v"1.11-" + else + if func == typeof(Core.memoryref) + attributes = function_attributes(llvmfn) + push!(attributes, EnumAttribute("alwaysinline", 0)) + end end -end meth = mi.def name = meth.name @@ -1011,7 +1008,6 @@ end Duplicated, nothing, run_enzyme, - world ) if cur state.primalf = llvmfn @@ -1028,23 +1024,23 @@ end attrs = if LLVM.version().major <= 15 LLVM.Attribute[LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute"), - EnumAttribute("willreturn"), - EnumAttribute("nosync"), - EnumAttribute("nounwind"), - EnumAttribute("nofree"), - ] + EnumAttribute("willreturn"), + EnumAttribute("nosync"), + EnumAttribute("nounwind"), + EnumAttribute("nofree"), + ] else LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"), - EnumAttribute("willreturn"), - EnumAttribute("nosync"), - EnumAttribute("nounwind"), - EnumAttribute("nofree")] + EnumAttribute("willreturn"), + EnumAttribute("nosync"), + EnumAttribute("nounwind"), + EnumAttribute("nofree")] end handleCustom(state, custom, k_name, llvmfn, name, attrs) return end -function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode) +function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing,LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode) for f in functions(mod) mi, RT = enzyme_custom_extract_mi(f, false) @@ -1188,12 +1184,12 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV state = HandlerState( primalf, - #=mustwrap=#false, - #=actualRetType=#nothing, - #=lowerConvention=#true, - #=loweredArgs=#Set{Int}(), - #=boxedArgs=#Set{Int}(), - #=fnsToInject=#Tuple{Symbol,Type}[], + false, #=mustwrap=# + nothing, #=actualRetType=# + true, #=lowerConvention=# + Set{Int}(), #=loweredArgs=# + Set{Int}(), #=boxedArgs=# + Tuple{Symbol,Type}[], #=fnsToInject=# ) for fname in LLVM.name.(functions(mod)) @@ -1244,11 +1240,11 @@ function nested_codegen!( target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false, libraries = true, toplevel = true, optimize = false, cleanup = false, only_entry = false, validate = false), world) + job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false, libraries=true, toplevel=true, optimize=false, cleanup=false, only_entry=false, validate=false), world) GPUCompiler.prepare_job!(job) otherMod, meta = GPUCompiler.emit_llvm(job) - + interp = GPUCompiler.get_interpreter(job) prepare_llvm(interp, otherMod, job, meta) @@ -1267,15 +1263,15 @@ function nested_codegen!( API.AddPreserveNVVMPass!(pm, true) #=Begin=# LLVM.run!(pm, otherMod) end - + if DumpPreNestedCheck[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end check_ir(interp, job, otherMod) - + if DumpPreNestedOpt[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end # Skipped inline of blas @@ -1285,11 +1281,11 @@ function nested_codegen!( # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) - + if DumpPostNestedOpt[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end - + # 4) Link the corresponding module LLVM.link!(mod, otherMod) # 5) Call the function @@ -1430,7 +1426,7 @@ function julia_post_cache_store( end p = pn - vals = get_julia_inner_types(B, p, v, added = added) + vals = get_julia_inner_types(B, p, v, added=added) r = emit_writebarrier!(B, vals) @assert isa(r, LLVM.Instruction) push!(added, r.ref) @@ -1494,29 +1490,29 @@ function julia_undef_value_for_type( end # If count is nothing, it represents that we have an allocation of one of `Ty`. If it is a tuple LLVM values, it represents {the total size in bytes, the aligned size of each element} -function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value), @nospecialize(count::Union{Nothing, Tuple{LLVM.Value, LLVM.ConstantInt}}))::Nothing +function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value), @nospecialize(count::Union{Nothing,Tuple{LLVM.Value,LLVM.ConstantInt}}))::Nothing if Base.datatype_pointerfree(Ty) return end isboxed_ref = Ref{Bool}() LLVMType = LLVM.LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef, - (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref)) + (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref)) if !isboxed_ref[] zeroAll = false prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) - atomic = true - if count === nothing - T_int64 = LLVM.Int64Type() + atomic = true + if count === nothing + T_int64 = LLVM.Int64Type() zero_single_allocation(B, Ty, LLVMType, prev, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic) - nothing - else - (Size, AlignedSize) = count - zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) - nothing - end + nothing + else + (Size, AlignedSize) = count + zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) + nothing + end else if fieldcount(Ty) == 0 error("Error handling recursive stores for $Ty which has a fieldcount of 0") @@ -1527,64 +1523,64 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), T_int8 = LLVM.Int8Type() T_int64 = LLVM.Int64Type() - + T_pint8 = LLVM.PointerType(T_int8) prev2 = bitcast!(B, prev, LLVM.PointerType(T_int8, addrspace(value_type(prev)))) typedesc = Base.DataTypeFieldDesc(Ty) - needs_fullzero = false - if count !== nothing - for i in 1:fieldcount(Ty) - Ty2 = fieldtype(Ty, i) - off = fieldoffset(Ty, i) - - if typedesc[i].isptr || !(off == 0 && Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2)) - needs_fullzero = true - break - end - end - end - - if needs_fullzero - zeroAll = false - prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) - prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) - atomic = true - (Size, AlignedSize) = count - zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) - nothing - else - for i in 1:fieldcount(Ty) - Ty2 = fieldtype(Ty, i) - off = fieldoffset(Ty, i) - - prev3 = inbounds_gep!( - B, - T_int8, - prev2, - LLVM.Value[LLVM.ConstantInt(Int64(off))], - ) - - if typedesc[i].isptr - @assert count === nothing - Ty2 = Any - zeroAll = false - prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3)))) - if addrspace(value_type(prev3)) != Derived - prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived)) - end - zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) - else - if count !== nothing - @assert off == 0 - @assert Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2) - end - create_recursive_stores(B, Ty2, prev3, count) - end - end - nothing - end + needs_fullzero = false + if count !== nothing + for i in 1:fieldcount(Ty) + Ty2 = fieldtype(Ty, i) + off = fieldoffset(Ty, i) + + if typedesc[i].isptr || !(off == 0 && Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2)) + needs_fullzero = true + break + end + end + end + + if needs_fullzero + zeroAll = false + prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) + prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) + atomic = true + (Size, AlignedSize) = count + zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) + nothing + else + for i in 1:fieldcount(Ty) + Ty2 = fieldtype(Ty, i) + off = fieldoffset(Ty, i) + + prev3 = inbounds_gep!( + B, + T_int8, + prev2, + LLVM.Value[LLVM.ConstantInt(Int64(off))], + ) + + if typedesc[i].isptr + @assert count === nothing + Ty2 = Any + zeroAll = false + prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3)))) + if addrspace(value_type(prev3)) != Derived + prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived)) + end + zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) + else + if count !== nothing + @assert off == 0 + @assert Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2) + end + create_recursive_stores(B, Ty2, prev3, count) + end + end + nothing + end end end @@ -1598,62 +1594,62 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie count = nothing if !has arg = V - if isa(arg, LLVM.CallInst) - fn = LLVM.called_operand(arg) - nm = "" - if isa(fn, LLVM.Function) - nm = LLVM.name(fn) - end - - # Type tag is arg 3 - if nm == "julia.gc_alloc_obj" || - nm == "jl_gc_alloc_typed" || - nm == "ijl_gc_alloc_typed" - totalsize = operands(arg)[2] - - @assert value_type(totalsize) isa LLVM.IntegerType - - arg = operands(arg)[3] - - if isa(arg, LLVM.CallInst) - fn = LLVM.called_operand(arg) - nm = "" - if isa(fn, LLVM.Function) - nm = LLVM.name(fn) - end - if LLVM.callconv(arg) == 37 || nm == "julia.call" - index = 1 - if LLVM.callconv(arg) != 37 - fn = first(operands(arg)) - nm = LLVM.name(fn) - index += 1 - end - if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" - index += 1 - found = Any[] - legal, Ty = absint(operands(arg)[index], partial) - if legal && Ty == NTuple - legal, Ty = absint(operands(arg)[index+2]) - if legal - # count should represent {the total size in bytes, the aligned size of each element} - B = LLVM.IRBuilder() - position!(B, V) - alignsize = LLVM.ConstantInt(value_type(totalsize), Base.aligned_sizeof(Ty)) - count = (totalsize, alignsize) - has = true - end - end - end - end - end - end - end - - - if !has + if isa(arg, LLVM.CallInst) + fn = LLVM.called_operand(arg) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + + # Type tag is arg 3 + if nm == "julia.gc_alloc_obj" || + nm == "jl_gc_alloc_typed" || + nm == "ijl_gc_alloc_typed" + totalsize = operands(arg)[2] + + @assert value_type(totalsize) isa LLVM.IntegerType + + arg = operands(arg)[3] + + if isa(arg, LLVM.CallInst) + fn = LLVM.called_operand(arg) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + if LLVM.callconv(arg) == 37 || nm == "julia.call" + index = 1 + if LLVM.callconv(arg) != 37 + fn = first(operands(arg)) + nm = LLVM.name(fn) + index += 1 + end + if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" + index += 1 + found = Any[] + legal, Ty = absint(operands(arg)[index], partial) + if legal && Ty == NTuple + legal, Ty = absint(operands(arg)[index+2]) + if legal + # count should represent {the total size in bytes, the aligned size of each element} + B = LLVM.IRBuilder() + position!(B, V) + alignsize = LLVM.ConstantInt(value_type(totalsize), Base.aligned_sizeof(Ty)) + count = (totalsize, alignsize) + has = true + end + end + end + end + end + end + end + + + if !has fn = LLVM.parent(LLVM.parent(V)) - throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) - end + throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) + end end if mode == API.DEM_ReverseModePrimal || @@ -1667,9 +1663,9 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie operands(V)[3] = unsafe_to_llvm(B, Base.RefValue{Ty}) end end - + if Base.datatype_pointerfree(Ty) - return + return end if mode == API.DEM_ForwardMode && (used || idx != 0) @@ -1691,7 +1687,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie B = LLVM.IRBuilder() position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(prev))) - create_recursive_stores(B, Ty, prev, count) + create_recursive_stores(B, Ty, prev, count) end if (mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeCombined) && used # Zero any jlvalue_t inner elements of preceeding allocation. @@ -1715,8 +1711,8 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie # Julia could decide to dead store eliminate the memset (not being read before the store of jlvaluet'), resulting in an error B = LLVM.IRBuilder() position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(V))) - - create_recursive_stores(B, Ty, V, count) + + create_recursive_stores(B, Ty, V, count) end nothing @@ -1786,7 +1782,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D LLVMType, jlType, )] - + addedvals = LLVM.Value[] while length(todo) != 0 path, ty, jlty = popfirst!(todo) @@ -1826,9 +1822,9 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D typed_fieldtype(jlty, i) elseif !(jlty isa DataType) if eltype(ty) isa LLVM.PointerType && LLVM.addrspace(eltype(ty)) == 10 - Any + Any else - throw(AssertionError("jlty=$jlty ty=$ty")) + throw(AssertionError("jlty=$jlty ty=$ty")) end end npath = copy(path) @@ -1838,7 +1834,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D continue end if isa(ty, LLVM.VectorType) - @assert jlty isa DataType + @assert jlty isa DataType for i = 1:size(ty) npath = copy(path) push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) @@ -1904,7 +1900,7 @@ function zero_allocation( name = "zeroType." * string(jlType) if atomic - name = name * ".atomic" + name = name * ".atomic" end wrapper_f = LLVM.Function( @@ -1998,7 +1994,7 @@ function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMTyp TT = Compiler.tape_type(LLVMType) if esizeof(TT) != convert(Int, AlignedSize) GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize = - convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) LLVMType=strip(string(LLVMType)) + convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) LLVMType = strip(string(LLVMType)) emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) end @@ -2137,7 +2133,7 @@ function emit_inacterror(B::LLVM.API.LLVMBuilderRef, V::LLVM.API.LLVMValueRef, o funcT = LLVM.FunctionType( LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())], - vararg = true, + vararg=true, ) func, _ = get_function!(mod, "jl_errorf", funcT, LLVM.Attribute[EnumAttribute("noreturn")]) @@ -2151,7 +2147,7 @@ include("rules/llvmrules.jl") function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) - elseif x isa (Array{T,0} where T) + elseif x isa (Array{T,0} where {T}) x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else throw(EnzymeNonScalarReturnException(x, "")) @@ -2329,7 +2325,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true) +function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool=true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -2339,7 +2335,7 @@ function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true) return nothing, nothing end -function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool = true) +function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool=true) mi = nothing RT = nothing for fattr in collect(function_attributes(orig)) @@ -2360,7 +2356,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool = true) return mi, RT end -function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = true) +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool=true) ty = nothing byref = nothing for fattr in collect(parameter_attributes(fn, idx)) @@ -2399,7 +2395,7 @@ function enzyme!( parallel::Bool, @nospecialize(actualRetType::Type), wrap::Bool, - @nospecialize(modifiedBetween::NTuple{N, Bool} where N), + @nospecialize(modifiedBetween::NTuple{N,Bool} where {N}), returnPrimal::Bool, @nospecialize(expectedTapeType::Type), loweredArgs::Set{Int}, @@ -2481,14 +2477,14 @@ function enzyme!( push!(args_known_values, API.IntList()) end if length(uncacheable_args) != length(collect(parameters(primalf))) - msg = sprint() do io - println(io, "length(uncacheable_args) != length(collect(parameters(primalf)))", TT) - println(io, "TT=", TT) - println(io, "modifiedBetween=", modifiedBetween) - println(io, "uncacheable_args=", uncacheable_args) - println(io, "primal", string(primalf)) - end - throw(AssertionError(msg)) + msg = sprint() do io + println(io, "length(uncacheable_args) != length(collect(parameters(primalf)))", TT) + println(io, "TT=", TT) + println(io, "modifiedBetween=", modifiedBetween) + println(io, "uncacheable_args=", uncacheable_args) + println(io, "primal", string(primalf)) + end + throw(AssertionError(msg)) end @assert length(args_typeInfo) == length(collect(parameters(primalf))) @@ -2506,266 +2502,266 @@ function enzyme!( enzyme_context = EnzymeContext() GC.@preserve enzyme_context begin - LLVM.@dispose logic = Logic(enzyme_context) begin + LLVM.@dispose logic = Logic(enzyme_context) begin - TA = TypeAnalysis(logic) + TA = TypeAnalysis(logic) - retTT = if !isa(actualRetType, Union) && - actualRetType <: Tuple && - in(Any, actualRetType.parameters) - TypeTree() - else - typeTree = typetree(actualRetType, ctx, dl, seen) - if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType) - typeTree = copy(typeTree) - merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) - only!(typeTree, -1) - end - typeTree - end + retTT = if !isa(actualRetType, Union) && + actualRetType <: Tuple && + in(Any, actualRetType.parameters) + TypeTree() + else + typeTree = typetree(actualRetType, ctx, dl, seen) + if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType) + typeTree = copy(typeTree) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + end + typeTree + end - typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) + typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) - TapeType = Cvoid + TapeType = Cvoid - if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient - returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = - returnUsed && ( - retType == API.DFT_DUP_ARG || - retType == API.DFT_DUP_NONEED || - rt <: MixedDuplicated || - rt <: BatchMixedDuplicated - ) - returnUsed &= returnPrimal - augmented = API.EnzymeCreateAugmentedPrimal( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, #=returnUsed=# - shadowReturnUsed, #=shadowReturnUsed=# - typeInfo, - uncacheable_args, - false, - runtimeActivity, - strongZero, - width, - parallel, - ) #=atomicAdd=# + if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient + returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) + shadowReturnUsed = + returnUsed && ( + retType == API.DFT_DUP_ARG || + retType == API.DFT_DUP_NONEED || + rt <: MixedDuplicated || + rt <: BatchMixedDuplicated + ) + returnUsed &= returnPrimal + augmented = API.EnzymeCreateAugmentedPrimal( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, #=returnUsed=# + shadowReturnUsed, #=shadowReturnUsed=# + typeInfo, + uncacheable_args, + false, + runtimeActivity, + strongZero, + width, + parallel, + ) #=atomicAdd=# + + # 2. get new_primalf and tape + augmented_primalf = + LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) + tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) + utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) + if utape != C_NULL + TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} + tape = utape + elseif tape != C_NULL + TapeType = Compiler.tape_type(LLVMType(tape)) + else + TapeType = Cvoid + end + if expectedTapeType !== UnknownTapeType + @assert expectedTapeType === TapeType + end - # 2. get new_primalf and tape - augmented_primalf = - LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) - tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) - utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) - if utape != C_NULL - TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} - tape = utape - elseif tape != C_NULL - TapeType = Compiler.tape_type(LLVMType(tape)) - else - TapeType = Cvoid - end - if expectedTapeType !== UnknownTapeType - @assert expectedTapeType === TapeType - end - - if wrap - augmented_primalf = create_abi_wrapper( - augmented_primalf, - TT, - rt, - actualRetType, - API.DEM_ReverseModePrimal, - augmented, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity, - ) - end + if wrap + augmented_primalf = create_abi_wrapper( + augmented_primalf, + TT, + rt, + actualRetType, + API.DEM_ReverseModePrimal, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity, + ) + end - # TODOs: - # 1. Handle mutable or !pointerfree arguments by introducing caching - # + specifically by setting uncacheable_args[i] = true + # TODOs: + # 1. Handle mutable or !pointerfree arguments by introducing caching + # + specifically by setting uncacheable_args[i] = true + + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + false, + false, + API.DEM_ReverseModeGradient, + runtimeActivity, + strongZero, + width, #=mode=# + tape, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + augmented, + parallel, + ), + ) #=atomicAdd=# + if wrap + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeGradient, + augmented, + width, + false, + shadow_init, + world, + interp, + runtimeActivity + ) #=returnPrimal=# + end + elseif mode == API.DEM_ReverseModeCombined + returnUsed = !isghostty(actualRetType) + returnUsed &= returnPrimal + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + false, + API.DEM_ReverseModeCombined, + runtimeActivity, + strongZero, + width, #=mode=# + C_NULL, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + C_NULL, + parallel, + ), + ) #=atomicAdd=# + augmented_primalf = nothing + if wrap + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeCombined, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity + ) + end + elseif mode == API.DEM_ForwardMode + returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - adjointf = LLVM.Function( - API.EnzymeCreatePrimalAndGradient( - logic, - primalf, - retType, - args_activity, - TA, - false, - false, - API.DEM_ReverseModeGradient, - runtimeActivity, - strongZero, - width, #=mode=# - tape, - false, - typeInfo, #=forceAnonymousTape=# - uncacheable_args, - augmented, - parallel, - ), - ) #=atomicAdd=# - if wrap - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ReverseModeGradient, - augmented, - width, - false, - shadow_init, - world, - interp, - runtimeActivity - ) #=returnPrimal=# - end - elseif mode == API.DEM_ReverseModeCombined - returnUsed = !isghostty(actualRetType) - returnUsed &= returnPrimal - adjointf = LLVM.Function( - API.EnzymeCreatePrimalAndGradient( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, - false, - API.DEM_ReverseModeCombined, - runtimeActivity, - strongZero, - width, #=mode=# - C_NULL, - false, - typeInfo, #=forceAnonymousTape=# - uncacheable_args, - C_NULL, - parallel, - ), - ) #=atomicAdd=# - augmented_primalf = nothing - if wrap - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ReverseModeCombined, - nothing, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity - ) - end - elseif mode == API.DEM_ForwardMode - returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) + literal_rt = eltype(rt) - literal_rt = eltype(rt) + if !isghostty(literal_rt) && runtimeActivity && GPUCompiler.deserves_argbox(actualRetType) && !GPUCompiler.deserves_argbox(literal_rt) + else + returnUsed &= returnPrimal + end - if !isghostty(literal_rt) && runtimeActivity && GPUCompiler.deserves_argbox(actualRetType) && !GPUCompiler.deserves_argbox(literal_rt) - else - returnUsed &= returnPrimal - end - - adjointf = LLVM.Function( - API.EnzymeCreateForwardDiff( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, - API.DEM_ForwardMode, - runtimeActivity, - strongZero, - width, #=mode=# - C_NULL, - typeInfo, #=additionalArg=# - uncacheable_args, - ), - ) - augmented_primalf = nothing - if wrap - pf = adjointf - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ForwardMode, - nothing, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity - ) - end - else - @assert "Unhandled derivative mode", mode - end - if DumpPostWrap[] - API.EnzymeDumpModuleRef(mod.ref) - end + adjointf = LLVM.Function( + API.EnzymeCreateForwardDiff( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + API.DEM_ForwardMode, + runtimeActivity, + strongZero, + width, #=mode=# + C_NULL, + typeInfo, #=additionalArg=# + uncacheable_args, + ), + ) + augmented_primalf = nothing + if wrap + pf = adjointf + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ForwardMode, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity + ) + end + else + @assert "Unhandled derivative mode", mode + end + if DumpPostWrap[] + API.EnzymeDumpModuleRef(mod.ref) + end - # Rewrite enzyme_ignore_derivatives functions to the identity of their first argument. - to_delete = LLVM.Function[] - for fn in functions(mod) - if startswith(name(fn), "__enzyme_ignore_derivatives") - push!(to_delete, fn) - to_delete_inst = LLVM.CallInst[] - for u in LLVM.uses(fn) - ci = LLVM.user(u) - @assert isa(ci, LLVM.CallInst) - LLVM.replace_uses!(ci, operands(ci)[1]) - push!(to_delete_inst, ci) + # Rewrite enzyme_ignore_derivatives functions to the identity of their first argument. + to_delete = LLVM.Function[] + for fn in functions(mod) + if startswith(name(fn), "__enzyme_ignore_derivatives") + push!(to_delete, fn) + to_delete_inst = LLVM.CallInst[] + for u in LLVM.uses(fn) + ci = LLVM.user(u) + @assert isa(ci, LLVM.CallInst) + LLVM.replace_uses!(ci, operands(ci)[1]) + push!(to_delete_inst, ci) + end + for ci in to_delete_inst + LLVM.erase!(ci) + end + end end - for ci in to_delete_inst - LLVM.erase!(ci) + for fn in to_delete + LLVM.erase!(fn) end - end - end - for fn in to_delete - LLVM.erase!(fn) - end - LLVM.verify(mod) + LLVM.verify(mod) - API.EnzymeLogicErasePreprocessedFunctions(logic) - adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) - augmented_primalfname = - augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) - for f in collect(functions(mod)) - API.EnzymeFixupBatchedJuliaCallingConvention(f) - end - ModulePassManager() do pm - dce!(pm) - LLVM.run!(pm, mod) - end - fix_decayaddr!(mod) - adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] - augmented_primalf = - augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] - if DumpPostEnzyme[] - API.EnzymeDumpModuleRef(mod.ref) - end + API.EnzymeLogicErasePreprocessedFunctions(logic) + adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) + augmented_primalfname = + augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) + for f in collect(functions(mod)) + API.EnzymeFixupBatchedJuliaCallingConvention(f) + end + ModulePassManager() do pm + dce!(pm) + LLVM.run!(pm, mod) + end + fix_decayaddr!(mod) + adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] + augmented_primalf = + augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] + if DumpPostEnzyme[] + API.EnzymeDumpModuleRef(mod.ref) + end - return adjointf, augmented_primalf, TapeType - end # @dispose logic + return adjointf, augmented_primalf, TapeType + end # @dispose logic end # GC.preserve enzyme_context end @@ -2890,7 +2886,7 @@ function create_abi_wrapper( if is_adjoint NT = Tuple{ActiveRetTypes...} if any( - any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes ) NT = AnonymousStruct(NT) @@ -2926,7 +2922,7 @@ function create_abi_wrapper( dretTy = LLVM.LLVMType( API.EnzymeGetShadowType( width, - convert(LLVMType, actualRetType; allow_boxed = !(rettype <: Active)), + convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)), ), ) push!(T_wrapperargs, dretTy) @@ -2980,7 +2976,7 @@ function create_abi_wrapper( rty = if Base.isconcretetype(literal_rt) Base.RefValue{literal_rt} else - (Base.RefValue{T} where T <: literal_rt) + (Base.RefValue{T} where {T<:literal_rt}) end if width == 1 push!(sret_types, rty) @@ -3016,7 +3012,7 @@ function create_abi_wrapper( combinedReturn = if any( - any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types ) AnonymousStruct(Tuple{sret_types...}) else @@ -3059,7 +3055,7 @@ function create_abi_wrapper( end if tape != C_NULL tape = LLVM.LLVMType(tape) - jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed = true) + jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed=true) push!(T_wrapperargs, jltape) else needs_tape = false @@ -3121,7 +3117,7 @@ function create_abi_wrapper( llty = value_type(params[i]) - convty = convert(LLVMType, T′; allow_boxed = true) + convty = convert(LLVMType, T′; allow_boxed=true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) @assert Base.isconcretetype(T′) @@ -3665,9 +3661,8 @@ function lower_convention( entry_f::LLVM.Function, @nospecialize(actualRetType::Type), @nospecialize(RetActivity::Type), - @nospecialize(TT::Union{Type, Nothing}), + @nospecialize(TT::Union{Type,Nothing}), run_enzyme::Bool, - world::UInt ) entry_ft = LLVM.function_type(entry_f) @@ -3691,11 +3686,11 @@ function lower_convention( returnRoots = returnRoots !== nothing loweredReturn = RetActivity <: Active && !allocatedinline(actualRetType) - if (RetActivity <: Active || RetActivity <: MixedDuplicated || RetActivity <: BatchMixedDuplicated) && (allocatedinline(actualRetType) != allocatedinline(eltype(RetActivity))) - @assert !allocatedinline(actualRetType) - loweredReturn = true + if (RetActivity <: Active || RetActivity <: MixedDuplicated || RetActivity <: BatchMixedDuplicated) && (allocatedinline(actualRetType) != allocatedinline(eltype(RetActivity))) + @assert !allocatedinline(actualRetType) + loweredReturn = true end - + expected_RT = Nothing if loweredReturn @assert !sret @@ -3897,7 +3892,7 @@ function lower_convention( if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - + typeTree = copy(typetree(actualRetType, ctx, dl, seen)) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) @@ -3937,7 +3932,7 @@ function lower_convention( metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - + typeTree = copy(typetree(arg.typ, ctx, dl, seen)) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) @@ -4181,7 +4176,7 @@ function lower_convention( position!(builder, failure) - emit_error(builder, nothing, "Expected return type of primal to be "*string(expected_RT)*" but did not find a value of that type") + emit_error(builder, nothing, "Expected return type of primal to be " * string(expected_RT) * " but did not find a value of that type") unreachable!(builder) else push!( @@ -4227,9 +4222,7 @@ function lower_convention( attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), ) - if EnzymeRules.has_easy_rule_from_sig(Interpreter.simplify_kw(mi.specTypes); world) - push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow")) - end + for prev in collect(function_attributes(entry_f)) if kind(prev) == kind(StringAttribute("enzyme_ta_norecur")) push!(attributes, prev) @@ -4410,7 +4403,7 @@ function lower_convention( println(io, string(mod)) println( io, - LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction), + LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction), ) println(io, string(wrapper_f)) println(io, "Broken function") @@ -4422,7 +4415,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) +function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world=nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -4441,7 +4434,7 @@ const DumpPreOpt = Ref(false) function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget}) @assert output == :llvm - + config = job.config params = config.params @@ -4479,14 +4472,14 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT primal_config = CompilerConfig( primal_target, primal_params; - toplevel = config.toplevel, - always_inline = config.always_inline, - kernel = false, - libraries = true, - optimize = false, - cleanup = false, - only_entry = false, - validate = false, + toplevel=config.toplevel, + always_inline=config.always_inline, + kernel=false, + libraries=true, + optimize=false, + cleanup=false, + only_entry=false, + validate=false, # ??? entry_abi ) primal_job = CompilerJob(primal, primal_config, job.world) @@ -4727,7 +4720,6 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT job.config.params.rt, TT, params.run_enzyme, - job.world ) end @@ -4739,13 +4731,13 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT parallel = false process_module = false device_module = false - if primal_target isa GPUCompiler.NativeCompilerTarget - parallel = Base.Threads.nthreads() > 1 + if primal_target isa GPUCompiler.NativeCompilerTarget + parallel = Base.Threads.nthreads() > 1 else # All other targets are GPU targets parallel = true device_module = true - + if primal_target isa GPUCompiler.GCNCompilerTarget || primal_target isa GPUCompiler.MetalCompilerTarget process_module = true @@ -4790,7 +4782,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing - + if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst) calluse = nothing @@ -4832,7 +4824,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT if !API.HasFromStack(inst) && ((isa(inst, LLVM.CallInst) && - (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst) || isa(inst, LLVM.ExtractValueInst)) + (!isa(fn, LLVM.Function) || isempty(blocks(fn)))) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst) || isa(inst, LLVM.ExtractValueInst)) legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal @@ -4861,14 +4853,15 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT metadata(inst)["enzyme_type"] = to_md(ec, ctx) metadata(inst)["enzymejl_source_type_$(source_typ)"] = MDNode(LLVM.Metadata[]) metadata(inst)["enzymejl_byref_$(byref)"] = MDNode(LLVM.Metadata[]) - -@static if VERSION < v"1.11-" -else - legal2, obj = absint(inst) - if legal2 obj isa Memory && obj == typeof(obj).instance - metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[]) + + @static if VERSION < v"1.11-" + else + legal2, obj = absint(inst) + if legal2 + obj isa Memory && obj == typeof(obj).instance + metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[]) + end end -end end @@ -5101,7 +5094,7 @@ end adjointf, augmented_primalf, TapeType = enzyme!( job, - interp, + interp, mod, primalf, TT, @@ -5268,12 +5261,12 @@ end isempty(LLVM.blocks(fn)) && continue linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end - + delete!(mod_to_edges, mod) use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType, edges) + return mod, (; adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType, edges) end # Compiler result @@ -5426,265 +5419,265 @@ end ::Type{TapeType}, args::Vararg{Any,N}, ) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} - F = eltype(FA) - is_forward = - CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk - is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk - is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk - needs_tape = CC <: AdjointThunk - - argtt = tt.parameters[1] - rettype = rt.parameters[1] - argtypes = DataType[argtt.parameters...] - argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] - - if false && CC <: PrimalErrorThunk - primargs = [ - quote - convert($(eltype(T)), $(argexprs[i]).val) - end for (i, T) in enumerate(argtypes) - ] - return quote - fn.val($(primargs...)) - error( - "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", - ) - end + F = eltype(FA) + is_forward = + CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk + is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk + is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk + needs_tape = CC <: AdjointThunk + + argtt = tt.parameters[1] + rettype = rt.parameters[1] + argtypes = DataType[argtt.parameters...] + argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] + + if false && CC <: PrimalErrorThunk + primargs = [ + quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes) + ] + return quote + fn.val($(primargs...)) + error( + "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", + ) end + end - if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active || - rettype <: MixedDuplicated || - rettype <: BatchMixedDuplicated - if length(argtypes) + is_adjoint + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) - end + if !RawCall && !(CC <: PrimalErrorThunk) + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated + if length(argtypes) + is_adjoint + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) end - elseif rettype <: Const - if length(argtypes) + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) - end + end + elseif rettype <: Const + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) end - else - if length(argtypes) + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) - end + end + else + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) end end end + end - types = DataType[] + types = DataType[] - if !(rettype <: Const) && ( - isghostty(eltype(rettype)) || - Core.Compiler.isconstType(eltype(rettype)) || - eltype(rettype) === DataType - ) - rrt = eltype(rettype) - error("Return type `$rrt` not marked Const, but is ghost or const type.") - end + if !(rettype <: Const) && ( + isghostty(eltype(rettype)) || + Core.Compiler.isconstType(eltype(rettype)) || + eltype(rettype) === DataType + ) + rrt = eltype(rettype) + error("Return type `$rrt` not marked Const, but is ghost or const type.") + end - sret_types = Type[] # Julia types of all returned variables - # By ref values we create and need to preserve - ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` + sret_types = Type[] # Julia types of all returned variables + # By ref values we create and need to preserve + ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` - if !isghostty(F) && !Core.Compiler.isconstType(F) - isboxed = GPUCompiler.deserves_argbox(F) - argexpr = :(fn.val) + if !isghostty(F) && !Core.Compiler.isconstType(F) + isboxed = GPUCompiler.deserves_argbox(F) + argexpr = :(fn.val) - if isboxed - push!(types, Any) - else - push!(types, F) - end + if isboxed + push!(types, Any) + else + push!(types, F) + end - push!(ccexprs, argexpr) - if (FA <: Active) - return quote - error("Cannot have function with Active annotation, $FA") + push!(ccexprs, argexpr) + if (FA <: Active) + return quote + error("Cannot have function with Active annotation, $FA") + end + elseif !(FA <: Const) + argexpr = :(fn.dval) + F_ABI = F + if width == 1 + if (FA <: MixedDuplicated) + push!(types, Any) + else + push!(types, F_ABI) end - elseif !(FA <: Const) - argexpr = :(fn.dval) - F_ABI = F - if width == 1 - if (FA <: MixedDuplicated) - push!(types, Any) - else - push!(types, F_ABI) - end + else + if F_ABI <: BatchMixedDuplicated + F_ABI = Base.RefValue{F_ABI} + end + F_ABI = NTuple{width,F_ABI} + isboxedvec = GPUCompiler.deserves_argbox(F_ABI) + if isboxedvec + push!(types, Any) else - if F_ABI <: BatchMixedDuplicated - F_ABI = Base.RefValue{F_ABI} - end - F_ABI = NTuple{width, F_ABI} - isboxedvec = GPUCompiler.deserves_argbox(F_ABI) - if isboxedvec - push!(types, Any) - else - push!(types, F_ABI) - end + push!(types, F_ABI) end - push!(ccexprs, argexpr) end + push!(ccexprs, argexpr) end + end - i = 1 - ActiveRetTypes = Type[] + i = 1 + ActiveRetTypes = Type[] - for T in argtypes - source_typ = eltype(T) + for T in argtypes + source_typ = eltype(T) - expr = argexprs[i] - i += 1 - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - @assert T <: Const - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - continue + expr = argexprs[i] + i += 1 + if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) + @assert T <: Const + if is_adjoint + push!(ActiveRetTypes, Nothing) end + continue + end - isboxed = GPUCompiler.deserves_argbox(source_typ) + isboxed = GPUCompiler.deserves_argbox(source_typ) - argexpr = if RawCall - expr + argexpr = if RawCall + expr + else + Expr(:., expr, QuoteNode(:val)) + end + + if isboxed + push!(types, Any) + else + push!(types, source_typ) + end + + push!(ccexprs, argexpr) + + if T <: Const || T <: BatchDuplicatedFunc + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + continue + end + if CC <: PrimalErrorThunk + continue + end + if T <: Active + if is_adjoint + if width == 1 + push!(ActiveRetTypes, source_typ) + else + push!(ActiveRetTypes, NTuple{width,source_typ}) + end + end + elseif T <: Duplicated || T <: DuplicatedNoNeed + if RawCall + argexpr = argexprs[i] + i += 1 else - Expr(:., expr, QuoteNode(:val)) + argexpr = Expr(:., expr, QuoteNode(:dval)) end - if isboxed push!(types, Any) else push!(types, source_typ) end - + if is_adjoint + push!(ActiveRetTypes, Nothing) + end push!(ccexprs, argexpr) - - if T <: Const || T <: BatchDuplicatedFunc - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - continue + elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) end - if CC <: PrimalErrorThunk - continue + isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) + if isboxedvec + push!(types, Any) + else + push!(types, NTuple{width,source_typ}) end - if T <: Active - if is_adjoint - if width == 1 - push!(ActiveRetTypes, source_typ) - else - push!(ActiveRetTypes, NTuple{width,source_typ}) - end - end - elseif T <: Duplicated || T <: DuplicatedNoNeed - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end - if isboxed - push!(types, Any) - else - push!(types, source_typ) - end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) - elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end - isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) - if isboxedvec - push!(types, Any) - else - push!(types, NTuple{width,source_typ}) - end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) - elseif T <: MixedDuplicated - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: MixedDuplicated + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + push!(types, Any) + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: BatchMixedDuplicated + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + isboxedvec = + GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) + if isboxedvec push!(types, Any) - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) - elseif T <: BatchMixedDuplicated - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end - isboxedvec = - GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) - if isboxedvec - push!(types, Any) - else - push!(types, NTuple{width,Base.RefValue{source_typ}}) - end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) else - error("calling convention should be annotated, got $T") + push!(types, NTuple{width,Base.RefValue{source_typ}}) end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + else + error("calling convention should be annotated, got $T") end + end - jlRT = eltype(rettype) - if typeof(jlRT) == UnionAll - # Future improvement, add type assertion on load - jlRT = DataType - end + jlRT = eltype(rettype) + if typeof(jlRT) == UnionAll + # Future improvement, add type assertion on load + jlRT = DataType + end - if is_sret_union(jlRT) - jlRT = Any - end + if is_sret_union(jlRT) + jlRT = Any + end - # API.DFT_OUT_DIFF - if is_adjoint - if rettype <: Active || - rettype <: MixedDuplicated || - rettype <: BatchMixedDuplicated - # TODO handle batch width - if rettype <: Active - @assert allocatedinline(jlRT) - end - j_drT = if width == 1 - jlRT - else - NTuple{width,jlRT} - end - push!(types, j_drT) - push!(ccexprs, argexprs[i]) - i += 1 + # API.DFT_OUT_DIFF + if is_adjoint + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated + # TODO handle batch width + if rettype <: Active + @assert allocatedinline(jlRT) end - end - - if needs_tape - if !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) - push!(types, TapeType) - push!(ccexprs, argexprs[i]) + j_drT = if width == 1 + jlRT + else + NTuple{width,jlRT} end + push!(types, j_drT) + push!(ccexprs, argexprs[i]) i += 1 end + end + + if needs_tape + if !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) + push!(types, TapeType) + push!(ccexprs, argexprs[i]) + end + i += 1 + end ts_ctx = JuliaContext() ctx = context(ts_ctx) @@ -5694,7 +5687,7 @@ end if is_adjoint NT = Tuple{ActiveRetTypes...} if any( - any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes ) NT = AnonymousStruct(NT) @@ -5724,7 +5717,7 @@ end rty = if Base.isconcretetype(jlRT) Base.RefValue{jlRT} else - (Base.RefValue{T} where T <: jlRT) + (Base.RefValue{T} where {T<:jlRT}) end push!(sret_types, rty) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed @@ -5733,7 +5726,7 @@ end rty = if Base.isconcretetype(jlRT) Base.RefValue{jlRT} else - (Base.RefValue{T} where T <: jlRT) + (Base.RefValue{T} where {T<:jlRT}) end push!(sret_types, AnonymousStruct(NTuple{width,rty})) elseif CC <: AugmentedForwardThunk @@ -5753,7 +5746,7 @@ end end # calls fptr - llvmtys = LLVMType[convert(LLVMType, x; allow_boxed = true) for x in types] + llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] T_void = convert(LLVMType, Nothing) @@ -5761,7 +5754,7 @@ end (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} if any( - any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types ) combinedReturn = AnonymousStruct(combinedReturn) end @@ -5905,7 +5898,7 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String,Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), @@ -5950,7 +5943,7 @@ const DumpPrePostOpt = Ref(false) const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, String, Union{String, Nothing}, Type, String} +function _thunk(job, postopt::Bool=true)::Tuple{LLVM.Module,Vector{Any},String,Union{String,Nothing},Type,String} config = CompilerConfig(job.config; optimize=false) job = CompilerJob(job.source, config, job.world) mod, meta = compile(:llvm, job) @@ -5997,7 +5990,7 @@ end const cache = Dict{UInt,CompileResult}() -const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String,String}}() const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult @@ -6026,20 +6019,20 @@ end @inline function thunkbase( mi::Core.MethodInstance, - World::Union{UInt, Nothing}, + World::Union{UInt,Nothing}, @nospecialize(FA::Type{<:Annotation}), @nospecialize(A::Type{<:Annotation}), @nospecialize(TT::Type), Mode::API.CDerivativeMode, width::Int, - @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), + @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, - edges::Union{Nothing, Vector{Any}} + edges::Union{Nothing,Vector{Any}} ) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( @@ -6059,11 +6052,11 @@ end StrongZero ) #=abiwrap=# tmp_job = if World isa Nothing - jb = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) + jb = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) check_activity_cache_invalidations(jb.world) jb else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) end interp = GPUCompiler.get_interpreter(tmp_job) @@ -6113,9 +6106,9 @@ end StrongZero ) #=abiwrap=# job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) end # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions @@ -6239,31 +6232,31 @@ end function thunk end -function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) +function thunk_generator(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) @nospecialize - - slotnames = Core.svec(Symbol("#self#"), - :fakeworld, :fa, :a, :tt, :mode, :width, - :modifiedbetween, :returnprimal, :shadowinit, - :abi, :erriffuncwritten, :runtimeactivity, :strongzero) + + slotnames = Core.svec(Symbol("#self#"), + :fakeworld, :fa, :a, :tt, :mode, :width, + :modifiedbetween, :returnprimal, :shadowinit, + :abi, :erriffuncwritten, :runtimeactivity, :strongzero) stub = Core.GeneratedFunctionStub(thunk, slotnames, Core.svec()) ft = eltype(FA) primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match - + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) - + mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world)))) - + check_activity_cache_invalidations(world) edges = Any[] add_edge!(edges, mi) - + ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) @@ -6296,23 +6289,23 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no if Mode == API.DEM_ForwardMode - fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + fwd_sig = Tuple{typeof(EnzymeRules.forward),<:EnzymeRules.FwdConfig,<:Enzyme.EnzymeCore.Annotation,Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, fwd_sig) else - rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal),<:EnzymeRules.RevConfig,<:Enzyme.EnzymeCore.Annotation,Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, rev_sig) - - rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}} + + rev_sig = Tuple{typeof(EnzymeRules.reverse),<:EnzymeRules.RevConfig,<:Enzyme.EnzymeCore.Annotation,Union{Type{<:Enzyme.EnzymeCore.Annotation},Enzyme.EnzymeCore.Active},Any,Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, rev_sig) end - - ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} + + ina_sig = Tuple{typeof(EnzymeRules.inactive),Vararg{Any}} add_edge!(edges, ina_sig) - + for gen_sig in ( - Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, - Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, - Tuple{typeof(EnzymeRules.inactive_type), Type}, + Tuple{typeof(EnzymeRules.inactive_noinl),Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias),Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type),Type}, ) add_edge!(edges, gen_sig) end @@ -6357,27 +6350,27 @@ import GPUCompiler: deferred_codegen_jobs function deferred_id_codegen end -function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) +function deferred_id_generator(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) @nospecialize - + slotnames = Core.svec(Symbol("#self#"), - :fa, :a, :tt, :mode, :width, :modifiedbetween, - :returnprimal, :shadowinit, :expectedtapetype, - :erriffuncwritten, :runtimeactivity, :strongzero) + :fa, :a, :tt, :mode, :width, :modifiedbetween, + :returnprimal, :shadowinit, :expectedtapetype, + :erriffuncwritten, :runtimeactivity, :strongzero) stub = Core.GeneratedFunctionStub(deferred_id_generator, slotnames, Core.svec()) ft = eltype(FA) primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match - + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) - + mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world)))) - + target = EnzymeTarget() rt2 = if A isa UnionAll rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? Forward : Reverse, world, mi) @@ -6417,7 +6410,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode StrongZero ) #=abiwrap=# job = - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), world) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), world) addr = get_trampoline(job) id = Base.reinterpret(Int, pointer(addr)) From 20feb57c0650d85f4047679e540e8427a5317d2a Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Fri, 24 Oct 2025 23:59:53 +0200 Subject: [PATCH 2/8] test: added test for capturing `Called function is not the same type as the call!`. --- test/Project.toml | 3 ++ test/embedded_bitcode.jl | 97 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 test/embedded_bitcode.jl diff --git a/test/Project.toml b/test/Project.toml index 403a2464f4..9e6b6395ac 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" @@ -12,6 +13,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" +Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" @@ -29,5 +31,6 @@ Enzyme = {path = ".."} EnzymeTestUtils = {path = "../lib/EnzymeTestUtils"} [compat] +Clang_jll = "16.0.6" EnzymeTestUtils = "0.2.1" ParallelTestRunner = "1.0.1" diff --git a/test/embedded_bitcode.jl b/test/embedded_bitcode.jl new file mode 100644 index 0000000000..107aad221d --- /dev/null +++ b/test/embedded_bitcode.jl @@ -0,0 +1,97 @@ +using Enzyme +using Clang_jll +using Libdl +using Test + +const FUNC_LLVM_IR = """ + declare double @llvm.rint.f64(double) #1 + + define i32 @func(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.t, i8* nocapture readnone %arg.arr.0, i8* nocapture readnone %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) local_unnamed_addr #0 { + common.ret: + %.27 = fdiv double %arg.t, 1.000000e-02 + %.28 = tail call double @llvm.rint.f64(double %.27) + %.29 = fptosi double %.28 to i64 + %.42 = icmp slt i64 %.29, 0 + %.43 = select i1 %.42, i64 %arg.arr.5.0, i64 0 + %.44 = add i64 %.43, %.29 + %.55 = mul i64 %.44, %arg.arr.6.0 + %.56 = ptrtoint double* %arg.arr.4 to i64 + %.57 = add i64 %.55, %.56 + %.58 = inttoptr i64 %.57 to double* + %.59 = load double, double* %.58, align 8 + store double %.59, double* %retptr, align 8 + ret i32 0 + } + + define double @func_wrap({ i8*, i32, i8*, i8*, i32 }** %excinfo, double %arg.t, i8* %arg.arr.0, i8* %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) { + entry: + %tmp = alloca double, align 8 + %st = call i32 @func(double* %tmp, { i8*, i32, i8*, i8*, i32 }** %excinfo, double %arg.t, i8* %arg.arr.0, i8* %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) + %val = load double, double* %tmp, align 8 + ret double %val + } + + + attributes #0 = { mustprogress nofree nosync nounwind willreturn } + attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn } + attributes #2 = { noinline } +""" + + +tmp_dir = tempdir() +tmp_so_file = joinpath(tmp_dir, "func.so") +run( + pipeline( + `$(clang()) -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o $(tmp_so_file)`; + stdin=IOBuffer(FUNC_LLVM_IR) + ) +) + +lib = Libdl.dlopen(tmp_so_file) +const fptr = Libdl.dlsym(lib, :func_wrap) + + +function func_ccall(t::Float64, arr::AbstractVector{Float64}) + nitems = length(arr) + bitsize = Base.elsize(arr) + GC.@preserve arr begin + excinfo = Ptr{Ptr{Cvoid}}(C_NULL) + base::Ptr{Cdouble} = pointer(arr) + + ccall(fptr, Cdouble, + (Ptr{Ptr{Cvoid}}, Cdouble, Ptr{Cvoid}, Ptr{Cvoid}, + Clong, Clong, Ptr{Cdouble}, Clong, Clong), + excinfo, t, C_NULL, C_NULL, nitems, bitsize, + base, nitems, nitems * bitsize) + end +end + +@testset "Broken Function ccall + @view" begin + a = rand(10) + expected_grad_a = (nothing, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + grad_a = gradient(Reverse, func_ccall, Const(0.0), a) + @test expected_grad_a == grad_a + + + errstream = joinpath(tempdir(), "stdout.txt") + err_llvmir = nothing + b = @view a[1:5] + + @show errstream + + redirect_stdio(stdout=errstream, stderr=errstream, stdin=devnull) do + try + gradient(Reverse, func_ccall, Const(0.0), b) + catch e + err_llvmir = e + # finally + # redirect_stdout(old_stdout) + end + + @test err_llvmir !== nothing + @test occursin("Broken function", err_llvmir.info) + end + + errtxt = read(errstream, String) + @test occursin("Called function is not the same type as the call!", errtxt) +end From 8001a7fcc4a1e78bc10bddb03a1a0d947ba80134 Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Sat, 25 Oct 2025 00:43:57 +0200 Subject: [PATCH 3/8] refactor: undone unnecessary formatting. --- src/compiler.jl | 1473 +++++++++++++++++++++++------------------------ 1 file changed, 736 insertions(+), 737 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7d918cae0c..969e803dab 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -108,9 +108,9 @@ struct PrimalCompilerParams <: AbstractEnzymeCompilerParams end function EnzymeCompilerParams(TT, mode, width, rt, run_enzyme, abiwrap, - modifiedBetween, returnPrimal, shadowInit, - expectedTapeType, ABI, - err_if_func_written, runtimeActivity, strongZero) + modifiedBetween, returnPrimal, shadowInit, + expectedTapeType, ABI, + err_if_func_written, runtimeActivity, strongZero) params = PrimalCompilerParams(mode) EnzymeCompilerParams( params, @@ -132,7 +132,7 @@ function EnzymeCompilerParams(TT, mode, width, rt, run_enzyme, abiwrap, end DefaultCompilerTarget(; kwargs...) = - GPUCompiler.NativeCompilerTarget(; jlruntime=true, kwargs...) + GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...) # TODO: Audit uses function EnzymeTarget() @@ -157,15 +157,15 @@ if VERSION >= v"1.11.0-DEV.1552" always_inline::Any method_table::Core.MethodTable param_type::Type - last_fwd_rule_world::Union{Nothing,Tuple} - last_rev_rule_world::Union{Nothing,Tuple} - last_ina_rule_world::Union{Nothing,Tuple} + last_fwd_rule_world::Union{Nothing, Tuple} + last_rev_rule_world::Union{Nothing, Tuple} + last_ina_rule_world::Union{Nothing, Tuple} end @inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool, inactive_rule::Bool) = EnzymeCacheToken(target_type, always_inline, method_table, param_type, - is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig,<:Annotation,Type{<:Annotation},Vararg{Annotation}}, world)...,) : nothing, - is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig,<:Annotation,Type{<:Annotation},Vararg{Annotation}}, world)...,) : nothing, + is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, inactive_rule ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,) : nothing ) @@ -279,7 +279,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data if (T isa Type) T = T::Type legal = T ∈ Tys - + if legal if name == :ldexp if !(sparam_vals[2] <: Integer) @@ -318,7 +318,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data if (T isa Type) T = T::Type legal = T ∈ Tys - + if legal if !all(==(T), sparam_vals) legal = false @@ -532,11 +532,11 @@ include("typeutils/inference.jl") import .Interpreter: isKWCallSignature -const mod_to_edges = Dict{LLVM.Module,Vector{Any}}() +const mod_to_edges = Dict{LLVM.Module, Vector{Any}}() mutable struct HandlerState - primalf::Union{Nothing,LLVM.Function} + primalf::Union{Nothing, LLVM.Function} must_wrap::Bool - actualRetType::Union{Nothing,Type} + actualRetType::Union{Nothing, Type} lowerConvention::Bool loweredArgs::Set{Int} boxedArgs::Set{Int} @@ -544,7 +544,7 @@ mutable struct HandlerState end -function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM.Function, name::String, attrs::Vector{LLVM.Attribute}=LLVM.Attribute[], setlink::Bool=true, noinl::Bool=true) +function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM.Function, name::String, attrs::Vector{LLVM.Attribute} = LLVM.Attribute[], setlink::Bool = true, noinl::Bool = true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink @@ -561,7 +561,7 @@ function handleCustom(state::HandlerState, custom, k_name::String, llvmfn::LLVM. nothing end -function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, mode::API.CDerivativeMode, world::UInt, method_table, custom::Dict{String,LLVM.API.LLVMLinkage}, mod::LLVM.Module, mi::Core.MethodInstance, k_name::String, @nospecialize(rettype::Type))::Nothing +function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, mode::API.CDerivativeMode, world::UInt, method_table, custom::Dict{String, LLVM.API.LLVMLinkage}, mod::LLVM.Module, mi::Core.MethodInstance, k_name::String, @nospecialize(rettype::Type))::Nothing has_custom_rule = false specTypes = Interpreter.simplify_kw(mi.specTypes) @@ -610,13 +610,13 @@ function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, m func = mi.specTypes.parameters[1] - @static if VERSION < v"1.11-" - else - if func == typeof(Core.memoryref) - attributes = function_attributes(llvmfn) - push!(attributes, EnumAttribute("alwaysinline", 0)) - end +@static if VERSION < v"1.11-" +else + if func == typeof(Core.memoryref) + attributes = function_attributes(llvmfn) + push!(attributes, EnumAttribute("alwaysinline", 0)) end +end meth = mi.def name = meth.name @@ -1024,23 +1024,23 @@ function handle_compiled(state::HandlerState, edges::Vector, run_enzyme::Bool, m attrs = if LLVM.version().major <= 15 LLVM.Attribute[LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute"), - EnumAttribute("willreturn"), - EnumAttribute("nosync"), - EnumAttribute("nounwind"), - EnumAttribute("nofree"), - ] + EnumAttribute("willreturn"), + EnumAttribute("nosync"), + EnumAttribute("nounwind"), + EnumAttribute("nofree"), + ] else LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"), - EnumAttribute("willreturn"), - EnumAttribute("nosync"), - EnumAttribute("nounwind"), - EnumAttribute("nofree")] + EnumAttribute("willreturn"), + EnumAttribute("nosync"), + EnumAttribute("nounwind"), + EnumAttribute("nofree")] end handleCustom(state, custom, k_name, llvmfn, name, attrs) return end -function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing,LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode) +function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode) for f in functions(mod) mi, RT = enzyme_custom_extract_mi(f, false) @@ -1184,12 +1184,12 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing,LLVM state = HandlerState( primalf, - false, #=mustwrap=# - nothing, #=actualRetType=# - true, #=lowerConvention=# - Set{Int}(), #=loweredArgs=# - Set{Int}(), #=boxedArgs=# - Tuple{Symbol,Type}[], #=fnsToInject=# + #=mustwrap=#false, + #=actualRetType=#nothing, + #=lowerConvention=#true, + #=loweredArgs=#Set{Int}(), + #=boxedArgs=#Set{Int}(), + #=fnsToInject=#Tuple{Symbol,Type}[], ) for fname in LLVM.name.(functions(mod)) @@ -1240,11 +1240,11 @@ function nested_codegen!( target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false, libraries=true, toplevel=true, optimize=false, cleanup=false, only_entry=false, validate=false), world) + job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false, libraries = true, toplevel = true, optimize = false, cleanup = false, only_entry = false, validate = false), world) GPUCompiler.prepare_job!(job) otherMod, meta = GPUCompiler.emit_llvm(job) - + interp = GPUCompiler.get_interpreter(job) prepare_llvm(interp, otherMod, job, meta) @@ -1263,15 +1263,15 @@ function nested_codegen!( API.AddPreserveNVVMPass!(pm, true) #=Begin=# LLVM.run!(pm, otherMod) end - + if DumpPreNestedCheck[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end check_ir(interp, job, otherMod) - + if DumpPreNestedOpt[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end # Skipped inline of blas @@ -1281,11 +1281,11 @@ function nested_codegen!( # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) - + if DumpPostNestedOpt[] - API.EnzymeDumpModuleRef(otherMod.ref) + API.EnzymeDumpModuleRef(otherMod.ref) end - + # 4) Link the corresponding module LLVM.link!(mod, otherMod) # 5) Call the function @@ -1426,7 +1426,7 @@ function julia_post_cache_store( end p = pn - vals = get_julia_inner_types(B, p, v, added=added) + vals = get_julia_inner_types(B, p, v, added = added) r = emit_writebarrier!(B, vals) @assert isa(r, LLVM.Instruction) push!(added, r.ref) @@ -1490,29 +1490,29 @@ function julia_undef_value_for_type( end # If count is nothing, it represents that we have an allocation of one of `Ty`. If it is a tuple LLVM values, it represents {the total size in bytes, the aligned size of each element} -function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value), @nospecialize(count::Union{Nothing,Tuple{LLVM.Value,LLVM.ConstantInt}}))::Nothing +function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value), @nospecialize(count::Union{Nothing, Tuple{LLVM.Value, LLVM.ConstantInt}}))::Nothing if Base.datatype_pointerfree(Ty) return end isboxed_ref = Ref{Bool}() LLVMType = LLVM.LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef, - (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref)) + (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref)) if !isboxed_ref[] zeroAll = false prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) - atomic = true - if count === nothing - T_int64 = LLVM.Int64Type() + atomic = true + if count === nothing + T_int64 = LLVM.Int64Type() zero_single_allocation(B, Ty, LLVMType, prev, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic) - nothing - else - (Size, AlignedSize) = count - zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) - nothing - end + nothing + else + (Size, AlignedSize) = count + zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) + nothing + end else if fieldcount(Ty) == 0 error("Error handling recursive stores for $Ty which has a fieldcount of 0") @@ -1523,64 +1523,64 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), T_int8 = LLVM.Int8Type() T_int64 = LLVM.Int64Type() - + T_pint8 = LLVM.PointerType(T_int8) prev2 = bitcast!(B, prev, LLVM.PointerType(T_int8, addrspace(value_type(prev)))) typedesc = Base.DataTypeFieldDesc(Ty) - needs_fullzero = false - if count !== nothing - for i in 1:fieldcount(Ty) - Ty2 = fieldtype(Ty, i) - off = fieldoffset(Ty, i) - - if typedesc[i].isptr || !(off == 0 && Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2)) - needs_fullzero = true - break - end - end - end - - if needs_fullzero - zeroAll = false - prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) - prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) - atomic = true - (Size, AlignedSize) = count - zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) - nothing - else - for i in 1:fieldcount(Ty) - Ty2 = fieldtype(Ty, i) - off = fieldoffset(Ty, i) - - prev3 = inbounds_gep!( - B, - T_int8, - prev2, - LLVM.Value[LLVM.ConstantInt(Int64(off))], - ) - - if typedesc[i].isptr - @assert count === nothing - Ty2 = Any - zeroAll = false - prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3)))) - if addrspace(value_type(prev3)) != Derived - prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived)) - end - zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) - else - if count !== nothing - @assert off == 0 - @assert Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2) - end - create_recursive_stores(B, Ty2, prev3, count) - end - end - nothing - end + needs_fullzero = false + if count !== nothing + for i in 1:fieldcount(Ty) + Ty2 = fieldtype(Ty, i) + off = fieldoffset(Ty, i) + + if typedesc[i].isptr || !(off == 0 && Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2)) + needs_fullzero = true + break + end + end + end + + if needs_fullzero + zeroAll = false + prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) + prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) + atomic = true + (Size, AlignedSize) = count + zero_allocation(B, Ty, LLVMType, prev, AlignedSize, Size, zeroAll, atomic) + nothing + else + for i in 1:fieldcount(Ty) + Ty2 = fieldtype(Ty, i) + off = fieldoffset(Ty, i) + + prev3 = inbounds_gep!( + B, + T_int8, + prev2, + LLVM.Value[LLVM.ConstantInt(Int64(off))], + ) + + if typedesc[i].isptr + @assert count === nothing + Ty2 = Any + zeroAll = false + prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3)))) + if addrspace(value_type(prev3)) != Derived + prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived)) + end + zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) + else + if count !== nothing + @assert off == 0 + @assert Base.aligned_sizeof(Ty) == Base.aligned_sizeof(Ty2) + end + create_recursive_stores(B, Ty2, prev3, count) + end + end + nothing + end end end @@ -1594,62 +1594,62 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie count = nothing if !has arg = V - if isa(arg, LLVM.CallInst) - fn = LLVM.called_operand(arg) - nm = "" - if isa(fn, LLVM.Function) - nm = LLVM.name(fn) - end - - # Type tag is arg 3 - if nm == "julia.gc_alloc_obj" || - nm == "jl_gc_alloc_typed" || - nm == "ijl_gc_alloc_typed" - totalsize = operands(arg)[2] - - @assert value_type(totalsize) isa LLVM.IntegerType - - arg = operands(arg)[3] - - if isa(arg, LLVM.CallInst) - fn = LLVM.called_operand(arg) - nm = "" - if isa(fn, LLVM.Function) - nm = LLVM.name(fn) - end - if LLVM.callconv(arg) == 37 || nm == "julia.call" - index = 1 - if LLVM.callconv(arg) != 37 - fn = first(operands(arg)) - nm = LLVM.name(fn) - index += 1 - end - if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" - index += 1 - found = Any[] - legal, Ty = absint(operands(arg)[index], partial) - if legal && Ty == NTuple - legal, Ty = absint(operands(arg)[index+2]) - if legal - # count should represent {the total size in bytes, the aligned size of each element} - B = LLVM.IRBuilder() - position!(B, V) - alignsize = LLVM.ConstantInt(value_type(totalsize), Base.aligned_sizeof(Ty)) - count = (totalsize, alignsize) - has = true - end - end - end - end - end - end - end - - - if !has + if isa(arg, LLVM.CallInst) + fn = LLVM.called_operand(arg) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + + # Type tag is arg 3 + if nm == "julia.gc_alloc_obj" || + nm == "jl_gc_alloc_typed" || + nm == "ijl_gc_alloc_typed" + totalsize = operands(arg)[2] + + @assert value_type(totalsize) isa LLVM.IntegerType + + arg = operands(arg)[3] + + if isa(arg, LLVM.CallInst) + fn = LLVM.called_operand(arg) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + if LLVM.callconv(arg) == 37 || nm == "julia.call" + index = 1 + if LLVM.callconv(arg) != 37 + fn = first(operands(arg)) + nm = LLVM.name(fn) + index += 1 + end + if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" + index += 1 + found = Any[] + legal, Ty = absint(operands(arg)[index], partial) + if legal && Ty == NTuple + legal, Ty = absint(operands(arg)[index+2]) + if legal + # count should represent {the total size in bytes, the aligned size of each element} + B = LLVM.IRBuilder() + position!(B, V) + alignsize = LLVM.ConstantInt(value_type(totalsize), Base.aligned_sizeof(Ty)) + count = (totalsize, alignsize) + has = true + end + end + end + end + end + end + end + + + if !has fn = LLVM.parent(LLVM.parent(V)) - throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) - end + throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) + end end if mode == API.DEM_ReverseModePrimal || @@ -1663,9 +1663,9 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie operands(V)[3] = unsafe_to_llvm(B, Base.RefValue{Ty}) end end - + if Base.datatype_pointerfree(Ty) - return + return end if mode == API.DEM_ForwardMode && (used || idx != 0) @@ -1687,7 +1687,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie B = LLVM.IRBuilder() position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(prev))) - create_recursive_stores(B, Ty, prev, count) + create_recursive_stores(B, Ty, prev, count) end if (mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeCombined) && used # Zero any jlvalue_t inner elements of preceeding allocation. @@ -1711,8 +1711,8 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie # Julia could decide to dead store eliminate the memset (not being read before the store of jlvaluet'), resulting in an error B = LLVM.IRBuilder() position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(V))) - - create_recursive_stores(B, Ty, V, count) + + create_recursive_stores(B, Ty, V, count) end nothing @@ -1782,7 +1782,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D LLVMType, jlType, )] - + addedvals = LLVM.Value[] while length(todo) != 0 path, ty, jlty = popfirst!(todo) @@ -1822,9 +1822,9 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D typed_fieldtype(jlty, i) elseif !(jlty isa DataType) if eltype(ty) isa LLVM.PointerType && LLVM.addrspace(eltype(ty)) == 10 - Any + Any else - throw(AssertionError("jlty=$jlty ty=$ty")) + throw(AssertionError("jlty=$jlty ty=$ty")) end end npath = copy(path) @@ -1834,7 +1834,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D continue end if isa(ty, LLVM.VectorType) - @assert jlty isa DataType + @assert jlty isa DataType for i = 1:size(ty) npath = copy(path) push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) @@ -1900,7 +1900,7 @@ function zero_allocation( name = "zeroType." * string(jlType) if atomic - name = name * ".atomic" + name = name * ".atomic" end wrapper_f = LLVM.Function( @@ -1994,7 +1994,7 @@ function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMTyp TT = Compiler.tape_type(LLVMType) if esizeof(TT) != convert(Int, AlignedSize) GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize = - convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) LLVMType = strip(string(LLVMType)) + convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) LLVMType=strip(string(LLVMType)) emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) end @@ -2133,7 +2133,7 @@ function emit_inacterror(B::LLVM.API.LLVMBuilderRef, V::LLVM.API.LLVMValueRef, o funcT = LLVM.FunctionType( LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())], - vararg=true, + vararg = true, ) func, _ = get_function!(mod, "jl_errorf", funcT, LLVM.Attribute[EnumAttribute("noreturn")]) @@ -2147,7 +2147,7 @@ include("rules/llvmrules.jl") function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) - elseif x isa (Array{T,0} where {T}) + elseif x isa (Array{T,0} where T) x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else throw(EnzymeNonScalarReturnException(x, "")) @@ -2325,7 +2325,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool=true) +function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -2335,7 +2335,7 @@ function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool=true) return nothing, nothing end -function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool=true) +function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool = true) mi = nothing RT = nothing for fattr in collect(function_attributes(orig)) @@ -2356,7 +2356,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool=true) return mi, RT end -function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool=true) +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = true) ty = nothing byref = nothing for fattr in collect(parameter_attributes(fn, idx)) @@ -2395,7 +2395,7 @@ function enzyme!( parallel::Bool, @nospecialize(actualRetType::Type), wrap::Bool, - @nospecialize(modifiedBetween::NTuple{N,Bool} where {N}), + @nospecialize(modifiedBetween::NTuple{N, Bool} where N), returnPrimal::Bool, @nospecialize(expectedTapeType::Type), loweredArgs::Set{Int}, @@ -2477,14 +2477,14 @@ function enzyme!( push!(args_known_values, API.IntList()) end if length(uncacheable_args) != length(collect(parameters(primalf))) - msg = sprint() do io - println(io, "length(uncacheable_args) != length(collect(parameters(primalf)))", TT) - println(io, "TT=", TT) - println(io, "modifiedBetween=", modifiedBetween) - println(io, "uncacheable_args=", uncacheable_args) - println(io, "primal", string(primalf)) - end - throw(AssertionError(msg)) + msg = sprint() do io + println(io, "length(uncacheable_args) != length(collect(parameters(primalf)))", TT) + println(io, "TT=", TT) + println(io, "modifiedBetween=", modifiedBetween) + println(io, "uncacheable_args=", uncacheable_args) + println(io, "primal", string(primalf)) + end + throw(AssertionError(msg)) end @assert length(args_typeInfo) == length(collect(parameters(primalf))) @@ -2502,266 +2502,266 @@ function enzyme!( enzyme_context = EnzymeContext() GC.@preserve enzyme_context begin - LLVM.@dispose logic = Logic(enzyme_context) begin + LLVM.@dispose logic = Logic(enzyme_context) begin - TA = TypeAnalysis(logic) + TA = TypeAnalysis(logic) - retTT = if !isa(actualRetType, Union) && - actualRetType <: Tuple && - in(Any, actualRetType.parameters) - TypeTree() - else - typeTree = typetree(actualRetType, ctx, dl, seen) - if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType) - typeTree = copy(typeTree) - merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) - only!(typeTree, -1) - end - typeTree - end + retTT = if !isa(actualRetType, Union) && + actualRetType <: Tuple && + in(Any, actualRetType.parameters) + TypeTree() + else + typeTree = typetree(actualRetType, ctx, dl, seen) + if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType) + typeTree = copy(typeTree) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + end + typeTree + end - typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) + typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) - TapeType = Cvoid + TapeType = Cvoid - if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient - returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = - returnUsed && ( - retType == API.DFT_DUP_ARG || - retType == API.DFT_DUP_NONEED || - rt <: MixedDuplicated || - rt <: BatchMixedDuplicated - ) - returnUsed &= returnPrimal - augmented = API.EnzymeCreateAugmentedPrimal( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, #=returnUsed=# - shadowReturnUsed, #=shadowReturnUsed=# - typeInfo, - uncacheable_args, - false, - runtimeActivity, - strongZero, - width, - parallel, - ) #=atomicAdd=# - - # 2. get new_primalf and tape - augmented_primalf = - LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) - tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) - utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) - if utape != C_NULL - TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} - tape = utape - elseif tape != C_NULL - TapeType = Compiler.tape_type(LLVMType(tape)) - else - TapeType = Cvoid - end - if expectedTapeType !== UnknownTapeType - @assert expectedTapeType === TapeType - end + if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient + returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) + shadowReturnUsed = + returnUsed && ( + retType == API.DFT_DUP_ARG || + retType == API.DFT_DUP_NONEED || + rt <: MixedDuplicated || + rt <: BatchMixedDuplicated + ) + returnUsed &= returnPrimal + augmented = API.EnzymeCreateAugmentedPrimal( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, #=returnUsed=# + shadowReturnUsed, #=shadowReturnUsed=# + typeInfo, + uncacheable_args, + false, + runtimeActivity, + strongZero, + width, + parallel, + ) #=atomicAdd=# - if wrap - augmented_primalf = create_abi_wrapper( - augmented_primalf, - TT, - rt, - actualRetType, - API.DEM_ReverseModePrimal, - augmented, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity, - ) - end + # 2. get new_primalf and tape + augmented_primalf = + LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) + tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) + utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) + if utape != C_NULL + TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} + tape = utape + elseif tape != C_NULL + TapeType = Compiler.tape_type(LLVMType(tape)) + else + TapeType = Cvoid + end + if expectedTapeType !== UnknownTapeType + @assert expectedTapeType === TapeType + end + + if wrap + augmented_primalf = create_abi_wrapper( + augmented_primalf, + TT, + rt, + actualRetType, + API.DEM_ReverseModePrimal, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity, + ) + end - # TODOs: - # 1. Handle mutable or !pointerfree arguments by introducing caching - # + specifically by setting uncacheable_args[i] = true - - adjointf = LLVM.Function( - API.EnzymeCreatePrimalAndGradient( - logic, - primalf, - retType, - args_activity, - TA, - false, - false, - API.DEM_ReverseModeGradient, - runtimeActivity, - strongZero, - width, #=mode=# - tape, - false, - typeInfo, #=forceAnonymousTape=# - uncacheable_args, - augmented, - parallel, - ), - ) #=atomicAdd=# - if wrap - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ReverseModeGradient, - augmented, - width, - false, - shadow_init, - world, - interp, - runtimeActivity - ) #=returnPrimal=# - end - elseif mode == API.DEM_ReverseModeCombined - returnUsed = !isghostty(actualRetType) - returnUsed &= returnPrimal - adjointf = LLVM.Function( - API.EnzymeCreatePrimalAndGradient( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, - false, - API.DEM_ReverseModeCombined, - runtimeActivity, - strongZero, - width, #=mode=# - C_NULL, - false, - typeInfo, #=forceAnonymousTape=# - uncacheable_args, - C_NULL, - parallel, - ), - ) #=atomicAdd=# - augmented_primalf = nothing - if wrap - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ReverseModeCombined, - nothing, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity - ) - end - elseif mode == API.DEM_ForwardMode - returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) + # TODOs: + # 1. Handle mutable or !pointerfree arguments by introducing caching + # + specifically by setting uncacheable_args[i] = true - literal_rt = eltype(rt) + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + false, + false, + API.DEM_ReverseModeGradient, + runtimeActivity, + strongZero, + width, #=mode=# + tape, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + augmented, + parallel, + ), + ) #=atomicAdd=# + if wrap + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeGradient, + augmented, + width, + false, + shadow_init, + world, + interp, + runtimeActivity + ) #=returnPrimal=# + end + elseif mode == API.DEM_ReverseModeCombined + returnUsed = !isghostty(actualRetType) + returnUsed &= returnPrimal + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + false, + API.DEM_ReverseModeCombined, + runtimeActivity, + strongZero, + width, #=mode=# + C_NULL, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + C_NULL, + parallel, + ), + ) #=atomicAdd=# + augmented_primalf = nothing + if wrap + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeCombined, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity + ) + end + elseif mode == API.DEM_ForwardMode + returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - if !isghostty(literal_rt) && runtimeActivity && GPUCompiler.deserves_argbox(actualRetType) && !GPUCompiler.deserves_argbox(literal_rt) - else - returnUsed &= returnPrimal - end + literal_rt = eltype(rt) - adjointf = LLVM.Function( - API.EnzymeCreateForwardDiff( - logic, - primalf, - retType, - args_activity, - TA, - returnUsed, - API.DEM_ForwardMode, - runtimeActivity, - strongZero, - width, #=mode=# - C_NULL, - typeInfo, #=additionalArg=# - uncacheable_args, - ), - ) - augmented_primalf = nothing - if wrap - pf = adjointf - adjointf = create_abi_wrapper( - adjointf, - TT, - rt, - actualRetType, - API.DEM_ForwardMode, - nothing, - width, - returnPrimal, - shadow_init, - world, - interp, - runtimeActivity - ) - end - else - @assert "Unhandled derivative mode", mode - end - if DumpPostWrap[] - API.EnzymeDumpModuleRef(mod.ref) - end + if !isghostty(literal_rt) && runtimeActivity && GPUCompiler.deserves_argbox(actualRetType) && !GPUCompiler.deserves_argbox(literal_rt) + else + returnUsed &= returnPrimal + end + + adjointf = LLVM.Function( + API.EnzymeCreateForwardDiff( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + API.DEM_ForwardMode, + runtimeActivity, + strongZero, + width, #=mode=# + C_NULL, + typeInfo, #=additionalArg=# + uncacheable_args, + ), + ) + augmented_primalf = nothing + if wrap + pf = adjointf + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ForwardMode, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + runtimeActivity + ) + end + else + @assert "Unhandled derivative mode", mode + end + if DumpPostWrap[] + API.EnzymeDumpModuleRef(mod.ref) + end - # Rewrite enzyme_ignore_derivatives functions to the identity of their first argument. - to_delete = LLVM.Function[] - for fn in functions(mod) - if startswith(name(fn), "__enzyme_ignore_derivatives") - push!(to_delete, fn) - to_delete_inst = LLVM.CallInst[] - for u in LLVM.uses(fn) - ci = LLVM.user(u) - @assert isa(ci, LLVM.CallInst) - LLVM.replace_uses!(ci, operands(ci)[1]) - push!(to_delete_inst, ci) - end - for ci in to_delete_inst - LLVM.erase!(ci) - end - end + # Rewrite enzyme_ignore_derivatives functions to the identity of their first argument. + to_delete = LLVM.Function[] + for fn in functions(mod) + if startswith(name(fn), "__enzyme_ignore_derivatives") + push!(to_delete, fn) + to_delete_inst = LLVM.CallInst[] + for u in LLVM.uses(fn) + ci = LLVM.user(u) + @assert isa(ci, LLVM.CallInst) + LLVM.replace_uses!(ci, operands(ci)[1]) + push!(to_delete_inst, ci) end - for fn in to_delete - LLVM.erase!(fn) + for ci in to_delete_inst + LLVM.erase!(ci) end - LLVM.verify(mod) + end + end + for fn in to_delete + LLVM.erase!(fn) + end + LLVM.verify(mod) - API.EnzymeLogicErasePreprocessedFunctions(logic) - adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) - augmented_primalfname = - augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) - for f in collect(functions(mod)) - API.EnzymeFixupBatchedJuliaCallingConvention(f) - end - ModulePassManager() do pm - dce!(pm) - LLVM.run!(pm, mod) - end - fix_decayaddr!(mod) - adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] - augmented_primalf = - augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] - if DumpPostEnzyme[] - API.EnzymeDumpModuleRef(mod.ref) - end + API.EnzymeLogicErasePreprocessedFunctions(logic) + adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) + augmented_primalfname = + augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) + for f in collect(functions(mod)) + API.EnzymeFixupBatchedJuliaCallingConvention(f) + end + ModulePassManager() do pm + dce!(pm) + LLVM.run!(pm, mod) + end + fix_decayaddr!(mod) + adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] + augmented_primalf = + augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] + if DumpPostEnzyme[] + API.EnzymeDumpModuleRef(mod.ref) + end - return adjointf, augmented_primalf, TapeType - end # @dispose logic + return adjointf, augmented_primalf, TapeType + end # @dispose logic end # GC.preserve enzyme_context end @@ -2886,7 +2886,7 @@ function create_abi_wrapper( if is_adjoint NT = Tuple{ActiveRetTypes...} if any( - any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for b in ActiveRetTypes ) NT = AnonymousStruct(NT) @@ -2922,7 +2922,7 @@ function create_abi_wrapper( dretTy = LLVM.LLVMType( API.EnzymeGetShadowType( width, - convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)), + convert(LLVMType, actualRetType; allow_boxed = !(rettype <: Active)), ), ) push!(T_wrapperargs, dretTy) @@ -2976,7 +2976,7 @@ function create_abi_wrapper( rty = if Base.isconcretetype(literal_rt) Base.RefValue{literal_rt} else - (Base.RefValue{T} where {T<:literal_rt}) + (Base.RefValue{T} where T <: literal_rt) end if width == 1 push!(sret_types, rty) @@ -3012,7 +3012,7 @@ function create_abi_wrapper( combinedReturn = if any( - any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types ) AnonymousStruct(Tuple{sret_types...}) else @@ -3055,7 +3055,7 @@ function create_abi_wrapper( end if tape != C_NULL tape = LLVM.LLVMType(tape) - jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed=true) + jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed = true) push!(T_wrapperargs, jltape) else needs_tape = false @@ -3117,7 +3117,7 @@ function create_abi_wrapper( llty = value_type(params[i]) - convty = convert(LLVMType, T′; allow_boxed=true) + convty = convert(LLVMType, T′; allow_boxed = true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) @assert Base.isconcretetype(T′) @@ -3661,7 +3661,7 @@ function lower_convention( entry_f::LLVM.Function, @nospecialize(actualRetType::Type), @nospecialize(RetActivity::Type), - @nospecialize(TT::Union{Type,Nothing}), + @nospecialize(TT::Union{Type, Nothing}), run_enzyme::Bool, ) entry_ft = LLVM.function_type(entry_f) @@ -3686,11 +3686,11 @@ function lower_convention( returnRoots = returnRoots !== nothing loweredReturn = RetActivity <: Active && !allocatedinline(actualRetType) - if (RetActivity <: Active || RetActivity <: MixedDuplicated || RetActivity <: BatchMixedDuplicated) && (allocatedinline(actualRetType) != allocatedinline(eltype(RetActivity))) - @assert !allocatedinline(actualRetType) - loweredReturn = true + if (RetActivity <: Active || RetActivity <: MixedDuplicated || RetActivity <: BatchMixedDuplicated) && (allocatedinline(actualRetType) != allocatedinline(eltype(RetActivity))) + @assert !allocatedinline(actualRetType) + loweredReturn = true end - + expected_RT = Nothing if loweredReturn @assert !sret @@ -3892,7 +3892,7 @@ function lower_convention( if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - + typeTree = copy(typetree(actualRetType, ctx, dl, seen)) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) @@ -3932,7 +3932,7 @@ function lower_convention( metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - + typeTree = copy(typetree(arg.typ, ctx, dl, seen)) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) @@ -4176,7 +4176,7 @@ function lower_convention( position!(builder, failure) - emit_error(builder, nothing, "Expected return type of primal to be " * string(expected_RT) * " but did not find a value of that type") + emit_error(builder, nothing, "Expected return type of primal to be "*string(expected_RT)*" but did not find a value of that type") unreachable!(builder) else push!( @@ -4415,7 +4415,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world=nothing) +function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -4434,7 +4434,7 @@ const DumpPreOpt = Ref(false) function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget}) @assert output == :llvm - + config = job.config params = config.params @@ -4472,14 +4472,14 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT primal_config = CompilerConfig( primal_target, primal_params; - toplevel=config.toplevel, - always_inline=config.always_inline, - kernel=false, - libraries=true, - optimize=false, - cleanup=false, - only_entry=false, - validate=false, + toplevel = config.toplevel, + always_inline = config.always_inline, + kernel = false, + libraries = true, + optimize = false, + cleanup = false, + only_entry = false, + validate = false, # ??? entry_abi ) primal_job = CompilerJob(primal, primal_config, job.world) @@ -4731,13 +4731,13 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT parallel = false process_module = false device_module = false - if primal_target isa GPUCompiler.NativeCompilerTarget - parallel = Base.Threads.nthreads() > 1 + if primal_target isa GPUCompiler.NativeCompilerTarget + parallel = Base.Threads.nthreads() > 1 else # All other targets are GPU targets parallel = true device_module = true - + if primal_target isa GPUCompiler.GCNCompilerTarget || primal_target isa GPUCompiler.MetalCompilerTarget process_module = true @@ -4782,7 +4782,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing - + if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst) calluse = nothing @@ -4824,7 +4824,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT if !API.HasFromStack(inst) && ((isa(inst, LLVM.CallInst) && - (!isa(fn, LLVM.Function) || isempty(blocks(fn)))) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst) || isa(inst, LLVM.ExtractValueInst)) + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst) || isa(inst, LLVM.ExtractValueInst)) legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal @@ -4853,15 +4853,14 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT metadata(inst)["enzyme_type"] = to_md(ec, ctx) metadata(inst)["enzymejl_source_type_$(source_typ)"] = MDNode(LLVM.Metadata[]) metadata(inst)["enzymejl_byref_$(byref)"] = MDNode(LLVM.Metadata[]) - - @static if VERSION < v"1.11-" - else - legal2, obj = absint(inst) - if legal2 - obj isa Memory && obj == typeof(obj).instance - metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[]) - end + +@static if VERSION < v"1.11-" +else + legal2, obj = absint(inst) + if legal2 obj isa Memory && obj == typeof(obj).instance + metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[]) end +end end @@ -5094,7 +5093,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT adjointf, augmented_primalf, TapeType = enzyme!( job, - interp, + interp, mod, primalf, TT, @@ -5261,12 +5260,12 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT isempty(LLVM.blocks(fn)) && continue linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end - + delete!(mod_to_edges, mod) use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (; adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType, edges) + return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType, edges) end # Compiler result @@ -5419,265 +5418,265 @@ end ::Type{TapeType}, args::Vararg{Any,N}, ) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} - F = eltype(FA) - is_forward = - CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk - is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk - is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk - needs_tape = CC <: AdjointThunk - - argtt = tt.parameters[1] - rettype = rt.parameters[1] - argtypes = DataType[argtt.parameters...] - argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] - - if false && CC <: PrimalErrorThunk - primargs = [ - quote - convert($(eltype(T)), $(argexprs[i]).val) - end for (i, T) in enumerate(argtypes) - ] - return quote - fn.val($(primargs...)) - error( - "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", - ) + F = eltype(FA) + is_forward = + CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk + is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk + is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk + needs_tape = CC <: AdjointThunk + + argtt = tt.parameters[1] + rettype = rt.parameters[1] + argtypes = DataType[argtt.parameters...] + argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] + + if false && CC <: PrimalErrorThunk + primargs = [ + quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes) + ] + return quote + fn.val($(primargs...)) + error( + "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", + ) + end end - end - if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active || - rettype <: MixedDuplicated || - rettype <: BatchMixedDuplicated - if length(argtypes) + is_adjoint + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) + if !RawCall && !(CC <: PrimalErrorThunk) + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated + if length(argtypes) + is_adjoint + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) + end end - end - elseif rettype <: Const - if length(argtypes) + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) + elseif rettype <: Const + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) + end end - end - else - if length(argtypes) + needs_tape != length(argexprs) - return quote - throw(MethodError($CC(fptr), (fn, args...))) + else + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) + end end end end - end - types = DataType[] - - if !(rettype <: Const) && ( - isghostty(eltype(rettype)) || - Core.Compiler.isconstType(eltype(rettype)) || - eltype(rettype) === DataType - ) - rrt = eltype(rettype) - error("Return type `$rrt` not marked Const, but is ghost or const type.") - end + types = DataType[] - sret_types = Type[] # Julia types of all returned variables - # By ref values we create and need to preserve - ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` + if !(rettype <: Const) && ( + isghostty(eltype(rettype)) || + Core.Compiler.isconstType(eltype(rettype)) || + eltype(rettype) === DataType + ) + rrt = eltype(rettype) + error("Return type `$rrt` not marked Const, but is ghost or const type.") + end - if !isghostty(F) && !Core.Compiler.isconstType(F) - isboxed = GPUCompiler.deserves_argbox(F) - argexpr = :(fn.val) + sret_types = Type[] # Julia types of all returned variables + # By ref values we create and need to preserve + ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` - if isboxed - push!(types, Any) - else - push!(types, F) - end + if !isghostty(F) && !Core.Compiler.isconstType(F) + isboxed = GPUCompiler.deserves_argbox(F) + argexpr = :(fn.val) - push!(ccexprs, argexpr) - if (FA <: Active) - return quote - error("Cannot have function with Active annotation, $FA") - end - elseif !(FA <: Const) - argexpr = :(fn.dval) - F_ABI = F - if width == 1 - if (FA <: MixedDuplicated) - push!(types, Any) - else - push!(types, F_ABI) - end + if isboxed + push!(types, Any) else - if F_ABI <: BatchMixedDuplicated - F_ABI = Base.RefValue{F_ABI} + push!(types, F) + end + + push!(ccexprs, argexpr) + if (FA <: Active) + return quote + error("Cannot have function with Active annotation, $FA") end - F_ABI = NTuple{width,F_ABI} - isboxedvec = GPUCompiler.deserves_argbox(F_ABI) - if isboxedvec - push!(types, Any) + elseif !(FA <: Const) + argexpr = :(fn.dval) + F_ABI = F + if width == 1 + if (FA <: MixedDuplicated) + push!(types, Any) + else + push!(types, F_ABI) + end else - push!(types, F_ABI) + if F_ABI <: BatchMixedDuplicated + F_ABI = Base.RefValue{F_ABI} + end + F_ABI = NTuple{width, F_ABI} + isboxedvec = GPUCompiler.deserves_argbox(F_ABI) + if isboxedvec + push!(types, Any) + else + push!(types, F_ABI) + end end + push!(ccexprs, argexpr) end - push!(ccexprs, argexpr) end - end - i = 1 - ActiveRetTypes = Type[] + i = 1 + ActiveRetTypes = Type[] - for T in argtypes - source_typ = eltype(T) + for T in argtypes + source_typ = eltype(T) - expr = argexprs[i] - i += 1 - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - @assert T <: Const - if is_adjoint - push!(ActiveRetTypes, Nothing) + expr = argexprs[i] + i += 1 + if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) + @assert T <: Const + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + continue end - continue - end - - isboxed = GPUCompiler.deserves_argbox(source_typ) - argexpr = if RawCall - expr - else - Expr(:., expr, QuoteNode(:val)) - end + isboxed = GPUCompiler.deserves_argbox(source_typ) - if isboxed - push!(types, Any) - else - push!(types, source_typ) - end - - push!(ccexprs, argexpr) - - if T <: Const || T <: BatchDuplicatedFunc - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - continue - end - if CC <: PrimalErrorThunk - continue - end - if T <: Active - if is_adjoint - if width == 1 - push!(ActiveRetTypes, source_typ) - else - push!(ActiveRetTypes, NTuple{width,source_typ}) - end - end - elseif T <: Duplicated || T <: DuplicatedNoNeed - if RawCall - argexpr = argexprs[i] - i += 1 + argexpr = if RawCall + expr else - argexpr = Expr(:., expr, QuoteNode(:dval)) + Expr(:., expr, QuoteNode(:val)) end + if isboxed push!(types, Any) else push!(types, source_typ) end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) - elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end - isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) - if isboxedvec - push!(types, Any) - else - push!(types, NTuple{width,source_typ}) - end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end + push!(ccexprs, argexpr) - elseif T <: MixedDuplicated - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) - end - push!(types, Any) - if is_adjoint - push!(ActiveRetTypes, Nothing) + + if T <: Const || T <: BatchDuplicatedFunc + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + continue end - push!(ccexprs, argexpr) - elseif T <: BatchMixedDuplicated - if RawCall - argexpr = argexprs[i] - i += 1 - else - argexpr = Expr(:., expr, QuoteNode(:dval)) + if CC <: PrimalErrorThunk + continue end - isboxedvec = - GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) - if isboxedvec + if T <: Active + if is_adjoint + if width == 1 + push!(ActiveRetTypes, source_typ) + else + push!(ActiveRetTypes, NTuple{width,source_typ}) + end + end + elseif T <: Duplicated || T <: DuplicatedNoNeed + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + if isboxed + push!(types, Any) + else + push!(types, source_typ) + end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) + if isboxedvec + push!(types, Any) + else + push!(types, NTuple{width,source_typ}) + end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: MixedDuplicated + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end push!(types, Any) + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: BatchMixedDuplicated + if RawCall + argexpr = argexprs[i] + i += 1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + isboxedvec = + GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) + if isboxedvec + push!(types, Any) + else + push!(types, NTuple{width,Base.RefValue{source_typ}}) + end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) else - push!(types, NTuple{width,Base.RefValue{source_typ}}) + error("calling convention should be annotated, got $T") end - if is_adjoint - push!(ActiveRetTypes, Nothing) - end - push!(ccexprs, argexpr) - else - error("calling convention should be annotated, got $T") end - end - jlRT = eltype(rettype) - if typeof(jlRT) == UnionAll - # Future improvement, add type assertion on load - jlRT = DataType - end + jlRT = eltype(rettype) + if typeof(jlRT) == UnionAll + # Future improvement, add type assertion on load + jlRT = DataType + end - if is_sret_union(jlRT) - jlRT = Any - end + if is_sret_union(jlRT) + jlRT = Any + end - # API.DFT_OUT_DIFF - if is_adjoint - if rettype <: Active || - rettype <: MixedDuplicated || - rettype <: BatchMixedDuplicated - # TODO handle batch width - if rettype <: Active - @assert allocatedinline(jlRT) - end - j_drT = if width == 1 - jlRT - else - NTuple{width,jlRT} + # API.DFT_OUT_DIFF + if is_adjoint + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated + # TODO handle batch width + if rettype <: Active + @assert allocatedinline(jlRT) + end + j_drT = if width == 1 + jlRT + else + NTuple{width,jlRT} + end + push!(types, j_drT) + push!(ccexprs, argexprs[i]) + i += 1 end - push!(types, j_drT) - push!(ccexprs, argexprs[i]) - i += 1 end - end - if needs_tape - if !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) - push!(types, TapeType) - push!(ccexprs, argexprs[i]) + if needs_tape + if !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) + push!(types, TapeType) + push!(ccexprs, argexprs[i]) + end + i += 1 end - i += 1 - end ts_ctx = JuliaContext() ctx = context(ts_ctx) @@ -5687,7 +5686,7 @@ end if is_adjoint NT = Tuple{ActiveRetTypes...} if any( - any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for b in ActiveRetTypes ) NT = AnonymousStruct(NT) @@ -5717,7 +5716,7 @@ end rty = if Base.isconcretetype(jlRT) Base.RefValue{jlRT} else - (Base.RefValue{T} where {T<:jlRT}) + (Base.RefValue{T} where T <: jlRT) end push!(sret_types, rty) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed @@ -5726,7 +5725,7 @@ end rty = if Base.isconcretetype(jlRT) Base.RefValue{jlRT} else - (Base.RefValue{T} where {T<:jlRT}) + (Base.RefValue{T} where T <: jlRT) end push!(sret_types, AnonymousStruct(NTuple{width,rty})) elseif CC <: AugmentedForwardThunk @@ -5746,7 +5745,7 @@ end end # calls fptr - llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] + llvmtys = LLVMType[convert(LLVMType, x; allow_boxed = true) for x in types] T_void = convert(LLVMType, Nothing) @@ -5754,7 +5753,7 @@ end (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} if any( - any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types ) combinedReturn = AnonymousStruct(combinedReturn) end @@ -5898,7 +5897,7 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String,Nothing}), @nospecialize(TapeType), prepost::String) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), @@ -5943,7 +5942,7 @@ const DumpPrePostOpt = Ref(false) const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool=true)::Tuple{LLVM.Module,Vector{Any},String,Union{String,Nothing},Type,String} +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, String, Union{String, Nothing}, Type, String} config = CompilerConfig(job.config; optimize=false) job = CompilerJob(job.source, config, job.world) mod, meta = compile(:llvm, job) @@ -5990,7 +5989,7 @@ end const cache = Dict{UInt,CompileResult}() -const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String,String}}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult @@ -6019,20 +6018,20 @@ end @inline function thunkbase( mi::Core.MethodInstance, - World::Union{UInt,Nothing}, + World::Union{UInt, Nothing}, @nospecialize(FA::Type{<:Annotation}), @nospecialize(A::Type{<:Annotation}), @nospecialize(TT::Type), Mode::API.CDerivativeMode, width::Int, - @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), + @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, - edges::Union{Nothing,Vector{Any}} + edges::Union{Nothing, Vector{Any}} ) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( @@ -6052,11 +6051,11 @@ end StrongZero ) #=abiwrap=# tmp_job = if World isa Nothing - jb = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + jb = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) check_activity_cache_invalidations(jb.world) jb else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) end interp = GPUCompiler.get_interpreter(tmp_job) @@ -6106,9 +6105,9 @@ end StrongZero ) #=abiwrap=# job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) end # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions @@ -6232,31 +6231,31 @@ end function thunk end -function thunk_generator(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) +function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) @nospecialize - - slotnames = Core.svec(Symbol("#self#"), - :fakeworld, :fa, :a, :tt, :mode, :width, - :modifiedbetween, :returnprimal, :shadowinit, - :abi, :erriffuncwritten, :runtimeactivity, :strongzero) + + slotnames = Core.svec(Symbol("#self#"), + :fakeworld, :fa, :a, :tt, :mode, :width, + :modifiedbetween, :returnprimal, :shadowinit, + :abi, :erriffuncwritten, :runtimeactivity, :strongzero) stub = Core.GeneratedFunctionStub(thunk, slotnames, Core.svec()) ft = eltype(FA) primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match - + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) - + mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world)))) - + check_activity_cache_invalidations(world) edges = Any[] add_edge!(edges, mi) - + ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) @@ -6289,23 +6288,23 @@ function thunk_generator(world::UInt, source::Union{Method,LineNumberNode}, @nos if Mode == API.DEM_ForwardMode - fwd_sig = Tuple{typeof(EnzymeRules.forward),<:EnzymeRules.FwdConfig,<:Enzyme.EnzymeCore.Annotation,Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, fwd_sig) else - rev_sig = Tuple{typeof(EnzymeRules.augmented_primal),<:EnzymeRules.RevConfig,<:Enzyme.EnzymeCore.Annotation,Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, rev_sig) - - rev_sig = Tuple{typeof(EnzymeRules.reverse),<:EnzymeRules.RevConfig,<:Enzyme.EnzymeCore.Annotation,Union{Type{<:Enzyme.EnzymeCore.Annotation},Enzyme.EnzymeCore.Active},Any,Vararg{Enzyme.EnzymeCore.Annotation}} + + rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}} add_edge!(edges, rev_sig) end - - ina_sig = Tuple{typeof(EnzymeRules.inactive),Vararg{Any}} + + ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} add_edge!(edges, ina_sig) - + for gen_sig in ( - Tuple{typeof(EnzymeRules.inactive_noinl),Vararg{Any}}, - Tuple{typeof(EnzymeRules.noalias),Vararg{Any}}, - Tuple{typeof(EnzymeRules.inactive_type),Type}, + Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type), Type}, ) add_edge!(edges, gen_sig) end @@ -6350,27 +6349,27 @@ import GPUCompiler: deferred_codegen_jobs function deferred_id_codegen end -function deferred_id_generator(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N,Bool} where {N})), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) +function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type)) @nospecialize - + slotnames = Core.svec(Symbol("#self#"), - :fa, :a, :tt, :mode, :width, :modifiedbetween, - :returnprimal, :shadowinit, :expectedtapetype, - :erriffuncwritten, :runtimeactivity, :strongzero) + :fa, :a, :tt, :mode, :width, :modifiedbetween, + :returnprimal, :shadowinit, :expectedtapetype, + :erriffuncwritten, :runtimeactivity, :strongzero) stub = Core.GeneratedFunctionStub(deferred_id_generator, slotnames, Core.svec()) ft = eltype(FA) primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match - + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) - + mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world)))) - + target = EnzymeTarget() rt2 = if A isa UnionAll rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? Forward : Reverse, world, mi) @@ -6410,7 +6409,7 @@ function deferred_id_generator(world::UInt, source::Union{Method,LineNumberNode} StrongZero ) #=abiwrap=# job = - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), world) + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), world) addr = get_trampoline(job) id = Base.reinterpret(Int, pointer(addr)) From d1e792bab8c3861e27962837147674f830c70c96 Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Sat, 25 Oct 2025 00:51:10 +0200 Subject: [PATCH 4/8] fix: added back removed `has_easy_rule_from_sig`. --- src/compiler.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 969e803dab..a7e9ad57cd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -513,6 +513,9 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta) attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT)))), ) + if EnzymeRules.has_easy_rule_from_sig(Interpreter.simplify_kw(mi.specTypes); job.world) + push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow")) + end if returnRoots attr = StringAttribute("enzymejl_returnRoots", "") push!(parameter_attributes(llvmfn, 2), attr) @@ -1008,6 +1011,7 @@ end Duplicated, nothing, run_enzyme, + world ) if cur state.primalf = llvmfn @@ -3663,6 +3667,7 @@ function lower_convention( @nospecialize(RetActivity::Type), @nospecialize(TT::Union{Type, Nothing}), run_enzyme::Bool, + world::UInt ) entry_ft = LLVM.function_type(entry_f) @@ -4222,7 +4227,9 @@ function lower_convention( attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), ) - + if EnzymeRules.has_easy_rule_from_sig(Interpreter.simplify_kw(mi.specTypes); world) + push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow")) + end for prev in collect(function_attributes(entry_f)) if kind(prev) == kind(StringAttribute("enzyme_ta_norecur")) push!(attributes, prev) @@ -4720,6 +4727,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT job.config.params.rt, TT, params.run_enzyme, + job.world ) end @@ -6475,3 +6483,4 @@ end include("compiler/reflection.jl") end + From ef9878e00269f9943190f65366ef7195bc0f6dba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Fri, 24 Oct 2025 23:54:54 +0100 Subject: [PATCH 5/8] Remove extra newline --- src/compiler.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index a7e9ad57cd..64f24da593 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6483,4 +6483,3 @@ end include("compiler/reflection.jl") end - From 47582582c15b382a8253d0f50cadd5a2a90c047f Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Sat, 25 Oct 2025 01:01:17 +0200 Subject: [PATCH 6/8] refactor: removed debugging `@show` statement. --- test/embedded_bitcode.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/embedded_bitcode.jl b/test/embedded_bitcode.jl index 107aad221d..9ceec0a84f 100644 --- a/test/embedded_bitcode.jl +++ b/test/embedded_bitcode.jl @@ -77,8 +77,6 @@ end err_llvmir = nothing b = @view a[1:5] - @show errstream - redirect_stdio(stdout=errstream, stderr=errstream, stdin=devnull) do try gradient(Reverse, func_ccall, Const(0.0), b) From 7f3b57d0bf54847787cc62870237898895b1d80b Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Sat, 25 Oct 2025 01:02:22 +0200 Subject: [PATCH 7/8] refactor: removed commented lines. --- test/embedded_bitcode.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/embedded_bitcode.jl b/test/embedded_bitcode.jl index 9ceec0a84f..c1dbf0ab76 100644 --- a/test/embedded_bitcode.jl +++ b/test/embedded_bitcode.jl @@ -82,8 +82,6 @@ end gradient(Reverse, func_ccall, Const(0.0), b) catch e err_llvmir = e - # finally - # redirect_stdout(old_stdout) end @test err_llvmir !== nothing From 0d1f93478339a83b25a5bd75fe61a7c3f71d7ada Mon Sep 17 00:00:00 2001 From: Yousof Mardoukhi Date: Sat, 25 Oct 2025 11:52:30 +0200 Subject: [PATCH 8/8] fix: dropped the compat for `Clang_jll`. --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 9e6b6395ac..fc18f306b2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,6 +31,5 @@ Enzyme = {path = ".."} EnzymeTestUtils = {path = "../lib/EnzymeTestUtils"} [compat] -Clang_jll = "16.0.6" EnzymeTestUtils = "0.2.1" ParallelTestRunner = "1.0.1"