Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
ModifiedBetween = Val(falses_from_args(Val(1), args...))

tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt)

if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
if !allocatedinline(rt) || rt isa Union
forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true))
forward, adjoint = Enzyme.Compiler.thunk(FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true))
res = forward(f, args′...)
tape = res[1]
if ReturnPrimal
Expand All @@ -196,7 +194,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed
throw(ErrorException("Duplicated Returns not yet handled"))
end
thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
thunk = Enzyme.Compiler.thunk(FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
Expand Down Expand Up @@ -315,9 +313,8 @@ f(x) = x*x
ModifiedBetween = Val(falses_from_args(Val(1), args...))

tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt)

thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
thunk = Enzyme.Compiler.thunk(FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
ModifiedBetween, ReturnPrimal)
thunk(f, args′...)
end
Expand All @@ -337,8 +334,6 @@ code, as well as high-order differentiation.
end
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt)

if A isa UnionAll
rt = Core.Compiler.return_type(f.val, tt)
rt = A{rt}
Expand All @@ -353,7 +348,7 @@ code, as well as high-order differentiation.

ModifiedBetween = Val(falses_from_args(Val(1), args...))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
@assert primal_ptr === nothing
thunk = Compiler.CombinedAdjointThunk{FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr)
if rt <: Active
Expand Down Expand Up @@ -397,7 +392,6 @@ code, as well as high-order differentiation.
end
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt)

if RT isa UnionAll
rt = Core.Compiler.return_type(f.val, tt)
Expand All @@ -419,7 +413,7 @@ code, as well as high-order differentiation.
ModifiedBetween = Val(falses_from_args(Val(1), args...))


adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
@assert primal_ptr === nothing
thunk = Compiler.ForwardModeThunk{FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr)
thunk(f, args′...)
Expand All @@ -442,7 +436,6 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur
@inline function autodiff_deferred(mode::M, f::FA, args...) where {FA<:Annotation, M<:Mode}
args′ = annotate(args...)
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = GPUCompiler.codegen_world_age(Core.Typeof(f.val), tt)
rt = Core.Compiler.return_type(f.val, tt)
if rt === Union{}
error("return type is Union{}, giving up.")
Expand Down Expand Up @@ -514,10 +507,9 @@ result, ∂v, ∂A

tt = Tuple{map(eltype, args)...}

world = GPUCompiler.codegen_world_age(eltype(FA), tt)

@assert ReturnShadow
Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false))
Enzyme.Compiler.thunk(FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false))
end

"""
Expand Down Expand Up @@ -578,9 +570,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated

tt = Tuple{map(eltype, args)...}

world = GPUCompiler.codegen_world_age(eltype(FA), tt)

Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false))
Enzyme.Compiler.thunk(FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false))
end

"""
Expand Down Expand Up @@ -648,14 +638,13 @@ result, ∂v, ∂A
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}
world = GPUCompiler.codegen_world_age(eltype(FA), primal_tt)

# TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal))
nondef = Enzyme.Compiler.thunk(FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal))
TapeType = Compiler.get_tape_type(typeof(nondef[1]))
A2 = Compiler.return_type(typeof(nondef[1]))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
AugT = Compiler.AugmentedForwardThunk{FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}
@assert AugT == typeof(nondef[1])
AdjT = Compiler.AdjointThunk{FA, A2, TT, Val{width}, TapeType}
Expand Down Expand Up @@ -925,20 +914,18 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))

tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}}
tt = Tuple{Core.Typeof(x)}
world = GPUCompiler.codegen_world_age(Core.Typeof(f), tt)
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
World = Val(nothing)
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween)
primal, adjoint = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween)

if num * chunk == n_out_val
last_size = chunk
primal2, adjoint2 = primal, adjoint
else
last_size = n_out_val - (num-1)*chunk
tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}}
primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween)
primal2, adjoint2 = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween)
end

tmp = ntuple(num) do i
Expand All @@ -964,11 +951,10 @@ end
@inline function jacobian(::ReverseMode, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val}
tt′ = Tuple{Duplicated{Core.Typeof(x)}}
tt = Tuple{Core.Typeof(x)}
world = GPUCompiler.codegen_world_age(Core.Typeof(f), tt)
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween)
primal, adjoint = Enzyme.Compiler.thunk(FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween)
rows = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)
Expand Down
62 changes: 27 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ using .JIT

import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error

snapshot = GPUCompiler.ci_cache_snapshot()
safe_println(head, tail) = ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s%s\n",head, tail)
macro safe_show(exs...)
blk = Expr(:block)
Expand Down Expand Up @@ -918,7 +919,7 @@ end
function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width}
FT = Core.Typeof(fn)
ghos = isghostty(FT) || Core.Compiler.isconstType(FT)
forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)))
forward = thunk((ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)))
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
function fclosure()
res = forward(ft)
Expand All @@ -936,7 +937,7 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any,
# TODO make this AD subcall type stable
FT = Core.Typeof(fn)
ghos = isghostty(FT) || Core.Compiler.isconstType(FT)
forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween))
forward, adjoint = thunk((ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween))
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
taperef = Ref{Any}()

Expand Down Expand Up @@ -1085,9 +1086,8 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes)
dupClosure = false
end

world = GPUCompiler.codegen_world_age(FT, tt)

forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true))
forward = thunk((dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true))

res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...)

Expand Down Expand Up @@ -1146,9 +1146,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes)
dupClosure = false
end

world = GPUCompiler.codegen_world_age(FT, Tuple{$(ElTypes...)})

forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT},
forward, adjoint = thunk((dupClosure ? Duplicated : Const){FT},
annotation, tt′, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true))

Expand Down Expand Up @@ -1253,9 +1252,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes)
if dupClosure && (isghostty(FT) || Core.Compiler.isconstType(FT))
dupClosure = false
end
world = GPUCompiler.codegen_world_age(FT, tt)

forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width,
forward, adjoint = thunk((dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true))
if tape.shadow_return !== nothing
args = (args..., $shadowret)
Expand Down Expand Up @@ -5935,6 +5933,7 @@ function emit_inacterror(B, V, orig)
end

function __init__()
GPUCompiler.ci_cache_insert(persistent_cache)
current_task_offset()
@static if VERSION < v"1.7.0"
else
Expand Down Expand Up @@ -8711,12 +8710,12 @@ end
@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed

@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World}
mi = fspec(eltype(FA), TT, World)
function thunk(::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit}
mi = methodinstance(eltype(FA), eltype(TT))

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), GPUCompiler.tls_world_age())

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}

Expand Down Expand Up @@ -8757,34 +8756,28 @@ end
TapeType = thunk.TapeType
AugT = AugmentedForwardThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType}
AdjT = AdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType}
return quote
augmented = $AugT($(thunk.primal))
adjoint = $AdjT($(thunk.adjoint))
(augmented, adjoint)
end
augmented = AugT(thunk.primal)
adjoint = AdjT(thunk.adjoint)
return (augmented, adjoint)
elseif Mode == API.DEM_ReverseModeCombined
CAdjT = CombinedAdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
return quote
$CAdjT($(thunk.adjoint))
end
return CAdjT(thunk.adjoint)
elseif Mode == API.DEM_ForwardMode
FMT = ForwardModeThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
return quote
$FMT($(thunk.adjoint))
end
return FMT(thunk.adjoint)
else
@assert false
end
end

import GPUCompiler: deferred_codegen_jobs

@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode},
::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType}
mi = fspec(eltype(FA), tt, World)
function deferred_codegen(::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode},
::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType}
mi = GPUCompiler.methodinstance(eltype(FA), eltype(tt))
target = EnzymeTarget()
params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), GPUCompiler.tls_world_age())

adjoint_addr, primal_addr = get_trampoline(job)
adjoint_id = Base.reinterpret(Int, pointer(adjoint_addr))
Expand All @@ -8797,17 +8790,16 @@ import GPUCompiler: deferred_codegen_jobs
primal_id = 0
end

quote
adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, adjoint_id)))
primal = if $(primal_addr !== nothing)
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, primal_id)))
else
nothing
end
adjoint, primal
adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), (reinterpret(Ptr{Cvoid}, adjoint_id)))
primal = if (primal_addr !== nothing)
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), (reinterpret(Ptr{Cvoid}, primal_id)))
else
nothing
end
return adjoint, primal
end

include("compiler/reflection.jl")

current_task_offset()
const persistent_cache = GPUCompiler.ci_cache_delta(snapshot)
end