@@ -164,6 +164,8 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
164164end
165165
166166# GPUCompiler intrinsic that marks deferred compilation.
167+ # In contrast to `deferred_codegen`, this doesn't support arbitrary
168+ # jobs as call targets. No methods are attached by this declaration.
167169function var"gpuc.deferred" end
168170
169171# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
@@ -221,12 +223,28 @@ const __llvm_initialized = Ref(false)
221223 # since those modules have been finalized themselves, and we don't want to re-finalize.
222224 entry = finish_module! (job, ir, entry)
223225
# Peel pointer/bit/address-space casts off a constant expression and return the
# innermost operand. Anything that is not a `ConstantExpr`, or whose opcode is
# not one of the three casts below, is returned as-is.
function unwrap_constant(val)
    # opcodes that only re-type their first operand
    cast_opcodes = (
        LLVM.API.LLVMIntToPtr,
        LLVM.API.LLVMBitCast,
        LLVM.API.LLVMAddrSpaceCast,
    )
    while val isa ConstantExpr && opcode(val) in cast_opcodes
        val = first(operands(val))
    end
    return val
end
238+
224239 # deferred code generation
225240 has_deferred_jobs = ! only_entry && toplevel &&
226- haskey (functions (ir), " deferred_codegen" )
241+ (haskey (functions (ir), " deferred_codegen" ) ||
242+ haskey (functions (ir), " gpuc.lookup" ))
243+
227244 jobs = Dict {CompilerJob, String} (job => entry_fn)
228245 if has_deferred_jobs
229- dyn_marker = functions (ir)[" deferred_codegen" ]
246+ dyn_marker = haskey (functions (ir), " deferred_codegen" ) ? functions (ir)[" deferred_codegen" ] : nothing
247+ dyn_marker_v2 = haskey (functions (ir), " gpuc.lookup" ) ? functions (ir)[" gpuc.lookup" ] : nothing
230248
231249 # iterative compilation (non-recursive)
232250 changed = true
@@ -235,26 +253,40 @@ const __llvm_initialized = Ref(false)
235253
236254 # find deferred compiler
237255 # TODO : recover this information earlier, from the Julia IR
256+ # We can do this now with gpuc.lookup
238257 worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
239- for use in uses (dyn_marker)
240- # decode the call
241- call = user (use):: LLVM.CallInst
242- id = convert (Int, first (operands (call)))
243-
244- global deferred_codegen_jobs
245- dyn_val = deferred_codegen_jobs[id]
246-
247- # get a job in the appropriate world
248- dyn_job = if dyn_val isa CompilerJob
249- # trust that the user knows what they're doing
250- dyn_val
251- else
252- ft, tt = dyn_val
253- dyn_src = methodinstance (ft, tt, tls_world_age ())
254- CompilerJob (dyn_src, job. config)
258+ if dyn_marker != = nothing
259+ for use in uses (dyn_marker)
260+ # decode the call
261+ call = user (use):: LLVM.CallInst
262+ id = convert (Int, first (operands (call)))
263+
264+ global deferred_codegen_jobs
265+ dyn_val = deferred_codegen_jobs[id]
266+
267+ # get a job in the appropriate world
268+ dyn_job = if dyn_val isa CompilerJob
269+ # trust that the user knows what they're doing
270+ dyn_val
271+ else
272+ ft, tt = dyn_val
273+ dyn_src = methodinstance (ft, tt, tls_world_age ())
274+ CompilerJob (dyn_src, job. config)
275+ end
276+
277+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
255278 end
279+ end
256280
257- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
281+ if dyn_marker_v2 != = nothing
282+ for use in uses (dyn_marker_v2)
283+ # decode the call
284+ call = user (use):: LLVM.CallInst
285+ dyn_mi = Base. unsafe_pointer_to_objref (
286+ convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
287+ dyn_job = CompilerJob (dyn_mi, job. config)
288+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
289+ end
258290 end
259291
260292 # compile and link
@@ -296,8 +328,15 @@ const __llvm_initialized = Ref(false)
296328 end
297329
298330 # all deferred compilations should have been resolved
299- @compiler_assert isempty (uses (dyn_marker)) job
300- unsafe_delete! (ir, dyn_marker)
331+ if dyn_marker != = nothing
332+ @compiler_assert isempty (uses (dyn_marker)) job
333+ unsafe_delete! (ir, dyn_marker)
334+ end
335+
336+ if dyn_marker_v2 != = nothing
337+ @compiler_assert isempty (uses (dyn_marker_v2)) job
338+ unsafe_delete! (ir, dyn_marker_v2)
339+ end
301340 end
302341
303342 if toplevel
0 commit comments