diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 231c99f993..5470b1eebd 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -178,13 +178,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) ModifiedBetween = Val(falses_from_args(Val(1), args...)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt) - if A <: Active tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f.val, tt) if !allocatedinline(rt) || rt isa Union - forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true)) + forward, adjoint = Enzyme.Compiler.thunk(FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true)) res = forward(f, args′...) tape = res[1] if ReturnPrimal @@ -196,7 +194,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed throw(ErrorException("Duplicated Returns not yet handled")) end - thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) + thunk = Enzyme.Compiler.thunk(FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) if A <: Active tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f.val, tt) @@ -315,9 +313,8 @@ f(x) = x*x ModifiedBetween = Val(falses_from_args(Val(1), args...)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt) - thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + thunk = Enzyme.Compiler.thunk(FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) thunk(f, args′...) end @@ -337,8 +334,6 @@ code, as well as high-order differentiation. end tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt) - if A isa UnionAll rt = Core.Compiler.return_type(f.val, tt) rt = A{rt} @@ -353,7 +348,7 @@ code, as well as high-order differentiation. ModifiedBetween = Val(falses_from_args(Val(1), args...)) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) + adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) @assert primal_ptr === nothing thunk = Compiler.CombinedAdjointThunk{FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr) if rt <: Active @@ -397,7 +392,6 @@ code, as well as high-order differentiation. end tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt) if RT isa UnionAll rt = Core.Compiler.return_type(f.val, tt) @@ -419,7 +413,7 @@ code, as well as high-order differentiation. ModifiedBetween = Val(falses_from_args(Val(1), args...)) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) + adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) @assert primal_ptr === nothing thunk = Compiler.ForwardModeThunk{FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr) thunk(f, args′...) @@ -442,7 +436,6 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur @inline function autodiff_deferred(mode::M, f::FA, args...) where {FA<:Annotation, M<:Mode} args′ = annotate(args...) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt) rt = Core.Compiler.return_type(f.val, tt) if rt === Union{} error("return type is Union{}, giving up.") @@ -514,10 +507,9 @@ result, ∂v, ∂A tt = Tuple{map(eltype, args)...} - world = GPUCompiler.codegen_world_age(eltype(FA), tt) @assert ReturnShadow - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false)) + Enzyme.Compiler.thunk(FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false)) end """ @@ -578,9 +570,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated tt = Tuple{map(eltype, args)...} - world = GPUCompiler.codegen_world_age(eltype(FA), tt) - - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false)) + Enzyme.Compiler.thunk(FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false)) end """ @@ -648,14 +638,13 @@ result, ∂v, ∂A TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} - world = GPUCompiler.codegen_world_age(eltype(FA), primal_tt) # TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal)) + nondef = Enzyme.Compiler.thunk(FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal)) TapeType = Compiler.get_tape_type(typeof(nondef[1])) A2 = Compiler.return_type(typeof(nondef[1])) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) AugT = Compiler.AugmentedForwardThunk{FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType} @assert AugT == typeof(nondef[1]) AdjT = Compiler.AdjointThunk{FA, A2, TT, Val{width}, TapeType} @@ -925,12 +914,10 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}} tt = Tuple{Core.Typeof(x)} - world = GPUCompiler.codegen_world_age(Core.Typeof(f), tt) rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} - World = Val(nothing) - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween) + primal, adjoint = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween) if num * chunk == n_out_val last_size = chunk @@ -938,7 +925,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween) + primal2, adjoint2 = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween) end tmp = ntuple(num) do i @@ -964,11 +951,10 @@ end @inline function jacobian(::ReverseMode, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val} tt′ = Tuple{Duplicated{Core.Typeof(x)}} tt = Tuple{Core.Typeof(x)} - world = GPUCompiler.codegen_world_age(Core.Typeof(f), tt) rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween) + primal, adjoint = Enzyme.Compiler.thunk(FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween) rows = ntuple(n_outs) do i Base.@_inline_meta dx = zero(x) diff --git a/src/compiler.jl b/src/compiler.jl index 33e51325d7..084659dcec 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -314,6 +314,7 @@ using .JIT import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error +snapshot = GPUCompiler.ci_cache_snapshot() safe_println(head, tail) = ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s%s\n",head, tail) macro safe_show(exs...) blk = Expr(:block) @@ -918,7 +919,7 @@ end function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} FT = Core.Typeof(fn) ghos = isghostty(FT) || Core.Compiler.isconstType(FT) - forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,))) + forward = thunk((ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,))) ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -936,7 +937,7 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = isghostty(FT) || Core.Compiler.isconstType(FT) - forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween)) + forward, adjoint = thunk((ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween)) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -1085,9 +1086,8 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) dupClosure = false end - world = GPUCompiler.codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true)) + forward = thunk((dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true)) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -1146,9 +1146,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) dupClosure = false end - world = GPUCompiler.codegen_world_age(FT, Tuple{$(ElTypes...)}) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, + forward, adjoint = thunk((dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true)) @@ -1253,9 +1252,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes) if dupClosure && (isghostty(FT) || Core.Compiler.isconstType(FT)) dupClosure = false end - world = GPUCompiler.codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + forward, adjoint = thunk((dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true)) if tape.shadow_return !== nothing args = (args..., $shadowret) @@ -5935,6 +5933,7 @@ function emit_inacterror(B, V, orig) end function __init__() + GPUCompiler.ci_cache_insert(persistent_cache) current_task_offset() @static if VERSION < v"1.7.0" else @@ -8711,12 +8710,12 @@ end @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed -@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World} - mi = fspec(eltype(FA), TT, World) +function thunk(::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit} + mi = methodinstance(eltype(FA), eltype(TT)) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType) - job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), GPUCompiler.tls_world_age()) sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} @@ -8757,21 +8756,15 @@ end TapeType = thunk.TapeType AugT = AugmentedForwardThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType} AdjT = AdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType} - return quote - augmented = $AugT($(thunk.primal)) - adjoint = $AdjT($(thunk.adjoint)) - (augmented, adjoint) - end + augmented = AugT(thunk.primal) + adjoint = AdjT(thunk.adjoint) + return (augmented, adjoint) elseif Mode == API.DEM_ReverseModeCombined CAdjT = CombinedAdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} - return quote - $CAdjT($(thunk.adjoint)) - end + return CAdjT(thunk.adjoint) elseif Mode == API.DEM_ForwardMode FMT = ForwardModeThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} - return quote - $FMT($(thunk.adjoint)) - end + return FMT(thunk.adjoint) else @assert false end @@ -8779,12 +8772,12 @@ end import GPUCompiler: deferred_codegen_jobs -@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType} - mi = fspec(eltype(FA), tt, World) +function deferred_codegen(::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode}, + ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType} + mi = GPUCompiler.methodinstance(eltype(FA), eltype(tt)) target = EnzymeTarget() params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType) - job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), GPUCompiler.tls_world_age()) adjoint_addr, primal_addr = get_trampoline(job) adjoint_id = Base.reinterpret(Int, pointer(adjoint_addr)) @@ -8797,17 +8790,16 @@ import GPUCompiler: deferred_codegen_jobs primal_id = 0 end - quote - adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, adjoint_id))) - primal = if $(primal_addr !== nothing) - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, primal_id))) - else - nothing - end - adjoint, primal + adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), (reinterpret(Ptr{Cvoid}, adjoint_id))) + primal = if (primal_addr !== nothing) + ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), (reinterpret(Ptr{Cvoid}, primal_id))) + else + nothing end + return adjoint, primal end include("compiler/reflection.jl") - +current_task_offset() +const persistent_cache = GPUCompiler.ci_cache_delta(snapshot) end