@@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
3939end
4040
4141
42+ # # deferred compilation
43+
44+ function var"gpuc.deferred" end
45+
46+ # old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
47+ begin
48+ # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
49+ # this could both be generalized (e.g. supporting actual function calls, instead of
50+ # returning a function pointer), and be integrated with the nonrecursive codegen.
51+ const deferred_codegen_jobs = Dict {Int, Any} ()
52+
53+ # We make this function explicitly callable so that we can drive OrcJIT's
54+ # lazy compilation from, while also enabling recursive compilation.
55+ Base. @ccallable Ptr{Cvoid} function deferred_codegen (ptr:: Ptr{Cvoid} )
56+ ptr
57+ end
58+
59+ @generated function deferred_codegen (:: Val{ft} , :: Val{tt} ) where {ft,tt}
60+ id = length (deferred_codegen_jobs) + 1
61+ deferred_codegen_jobs[id] = (; ft, tt)
62+ # don't bother looking up the method instance, as we'll do so again during codegen
63+ # using the world age of the parent.
64+ #
65+ # this also works around an issue on <1.10, where we don't know the world age of
66+ # generated functions so use the current world counter, which may be too new
67+ # for the world we're compiling for.
68+
69+ quote
70+ # TODO : add an edge to this method instance to support method redefinitions
71+ ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Int,), $ id)
72+ end
73+ end
74+ end
75+
76+
4277# # compiler entrypoint
4378
4479export compile
@@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
127162 error (" Unknown compilation output $output " )
128163end
129164
130- # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
131- # this could both be generalized (e.g. supporting actual function calls, instead of
132- # returning a function pointer), and be integrated with the nonrecursive codegen.
133- const deferred_codegen_jobs = Dict {Int, Any} ()
134-
135- # We make this function explicitly callable so that we can drive OrcJIT's
136- # lazy compilation from, while also enabling recursive compilation.
137- Base. @ccallable Ptr{Cvoid} function deferred_codegen (ptr:: Ptr{Cvoid} )
138- ptr
139- end
140-
141- @generated function deferred_codegen (:: Val{ft} , :: Val{tt} ) where {ft,tt}
142- id = length (deferred_codegen_jobs) + 1
143- deferred_codegen_jobs[id] = (; ft, tt)
144- # don't bother looking up the method instance, as we'll do so again during codegen
145- # using the world age of the parent.
146- #
147- # this also works around an issue on <1.10, where we don't know the world age of
148- # generated functions so use the current world counter, which may be too new
149- # for the world we're compiling for.
150-
151- quote
152- # TODO : add an edge to this method instance to support method redefinitions
153- ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Int,), $ id)
154- end
155- end
156-
157165const __llvm_initialized = Ref (false )
158166
159167@locked function emit_llvm (@nospecialize (job:: CompilerJob ); toplevel:: Bool ,
@@ -183,9 +191,82 @@ const __llvm_initialized = Ref(false)
183191 entry = finish_module! (job, ir, entry)
184192
185193 # deferred code generation
186- has_deferred_jobs = toplevel && ! only_entry && haskey (functions (ir), " deferred_codegen" )
194+ run_optimization_for_deferred = false
195+ if haskey (functions (ir), " gpuc.lookup" )
196+ run_optimization_for_deferred = true
197+ dyn_marker = functions (ir)[" gpuc.lookup" ]
198+
199+ # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
200+ # target method instance from the LLVM IR
201+ # TODO : drive deferred compilation from the Julia IR instead
202+ function find_base_object (val)
203+ while true
204+ if val isa ConstantExpr && (opcode (val) == LLVM. API. LLVMIntToPtr ||
205+ opcode (val) == LLVM. API. LLVMBitCast ||
206+ opcode (val) == LLVM. API. LLVMAddrSpaceCast)
207+ val = first (operands (val))
208+ elseif val isa LLVM. IntToPtrInst ||
209+ val isa LLVM. BitCastInst ||
210+ val isa LLVM. AddrSpaceCastInst
211+ val = first (operands (val))
212+ elseif val isa LLVM. LoadInst
213+ # In 1.11+ we no longer embed integer constants directly.
214+ gv = first (operands (val))
215+ if gv isa LLVM. GlobalValue
216+ val = LLVM. initializer (gv)
217+ continue
218+ end
219+ break
220+ else
221+ break
222+ end
223+ end
224+ return val
225+ end
226+
227+ worklist = Dict {Any, Vector{LLVM.CallInst}} ()
228+ for use in uses (dyn_marker)
229+ # decode the call
230+ call = user (use):: LLVM.CallInst
231+ dyn_mi_inst = find_base_object (operands (call)[1 ])
232+ @compiler_assert isa (dyn_mi_inst, LLVM. ConstantInt) job
233+ dyn_mi = Base. unsafe_pointer_to_objref (
234+ convert (Ptr{Cvoid}, convert (Int, dyn_mi_inst)))
235+ push! (get! (worklist, dyn_mi, LLVM. CallInst[]), call)
236+ end
237+
238+ for dyn_mi in keys (worklist)
239+ dyn_fn_name = compiled[dyn_mi]. specfunc
240+ dyn_fn = functions (ir)[dyn_fn_name]
241+
242+ # insert a pointer to the function everywhere the entry is used
243+ T_ptr = convert (LLVMType, Ptr{Cvoid})
244+ for call in worklist[dyn_mi]
245+ @dispose builder= IRBuilder () begin
246+ position! (builder, call)
247+ fptr = if LLVM. version () >= v " 17"
248+ T_ptr = LLVM. PointerType ()
249+ bitcast! (builder, dyn_fn, T_ptr)
250+ elseif VERSION >= v " 1.12.0-DEV.225"
251+ T_ptr = LLVM. PointerType (LLVM. Int8Type ())
252+ bitcast! (builder, dyn_fn, T_ptr)
253+ else
254+ ptrtoint! (builder, dyn_fn, T_ptr)
255+ end
256+ replace_uses! (call, fptr)
257+ end
258+ unsafe_delete! (LLVM. parent (call), call)
259+ end
260+ end
261+
262+ # all deferred compilations should have been resolved
263+ @compiler_assert isempty (uses (dyn_marker)) job
264+ unsafe_delete! (ir, dyn_marker)
265+ end
266+ # # old, deprecated implementation
187267 jobs = Dict {CompilerJob, String} (job => entry_fn)
188- if has_deferred_jobs
268+ if toplevel && ! only_entry && haskey (functions (ir), " deferred_codegen" )
269+ run_optimization_for_deferred = true
189270 dyn_marker = functions (ir)[" deferred_codegen" ]
190271
191272 # iterative compilation (non-recursive)
@@ -194,7 +275,6 @@ const __llvm_initialized = Ref(false)
194275 changed = false
195276
196277 # find deferred compiler
197- # TODO : recover this information earlier, from the Julia IR
198278 worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
199279 for use in uses (dyn_marker)
200280 # decode the call
@@ -317,7 +397,7 @@ const __llvm_initialized = Ref(false)
317397 # deferred codegen has some special optimization requirements,
318398 # which also need to happen _after_ regular optimization.
319399 # XXX : make these part of the optimizer pipeline?
320- if has_deferred_jobs
400+ if run_optimization_for_deferred
321401 @dispose pb= NewPMPassBuilder () begin
322402 add! (pb, NewPMFunctionPassManager ()) do fpm
323403 add! (fpm, InstCombinePass ())
0 commit comments