|
43 | 43 |
|
44 | 44 | function var"gpuc.deferred" end |
45 | 45 |
|
46 | | -# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic |
47 | | -begin |
48 | | - # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism. |
49 | | - # this could both be generalized (e.g. supporting actual function calls, instead of |
50 | | - # returning a function pointer), and be integrated with the nonrecursive codegen. |
51 | | - const deferred_codegen_jobs = Dict{Int, Any}() |
52 | | - |
53 | | - # We make this function explicitly callable so that we can drive OrcJIT's |
54 | | - # lazy compilation from, while also enabling recursive compilation. |
55 | | - Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid}) |
56 | | - ptr |
57 | | - end |
58 | | - |
59 | | - @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt} |
60 | | - id = length(deferred_codegen_jobs) + 1 |
61 | | - deferred_codegen_jobs[id] = (; ft, tt) |
62 | | - # don't bother looking up the method instance, as we'll do so again during codegen |
63 | | - # using the world age of the parent. |
64 | | - # |
65 | | - # this also works around an issue on <1.10, where we don't know the world age of |
66 | | - # generated functions so use the current world counter, which may be too new |
67 | | - # for the world we're compiling for. |
68 | | - |
69 | | - quote |
70 | | - # TODO: add an edge to this method instance to support method redefinitions |
71 | | - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id) |
72 | | - end |
73 | | - end |
74 | | -end |
75 | | - |
76 | | - |
77 | 46 | ## compiler entrypoint |
78 | 47 |
|
79 | 48 | export compile |
@@ -198,7 +167,6 @@ const __llvm_initialized = Ref(false) |
198 | 167 |
|
199 | 168 | # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the |
200 | 169 | # target method instance from the LLVM IR |
201 | | - # TODO: drive deferred compilation from the Julia IR instead |
202 | 170 | function find_base_object(val) |
203 | 171 | while true |
204 | 172 | if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr || |
@@ -263,80 +231,6 @@ const __llvm_initialized = Ref(false) |
263 | 231 | @compiler_assert isempty(uses(dyn_marker)) job |
264 | 232 | unsafe_delete!(ir, dyn_marker) |
265 | 233 | end |
266 | | - ## old, deprecated implementation |
267 | | - jobs = Dict{CompilerJob, String}(job => entry_fn) |
268 | | - if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") |
269 | | - run_optimization_for_deferred = true |
270 | | - dyn_marker = functions(ir)["deferred_codegen"] |
271 | | - |
272 | | - # iterative compilation (non-recursive) |
273 | | - changed = true |
274 | | - while changed |
275 | | - changed = false |
276 | | - |
277 | | - # find deferred compiler |
278 | | - worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}() |
279 | | - for use in uses(dyn_marker) |
280 | | - # decode the call |
281 | | - call = user(use)::LLVM.CallInst |
282 | | - id = convert(Int, first(operands(call))) |
283 | | - |
284 | | - global deferred_codegen_jobs |
285 | | - dyn_val = deferred_codegen_jobs[id] |
286 | | - |
287 | | - # get a job in the appopriate world |
288 | | - dyn_job = if dyn_val isa CompilerJob |
289 | | - # trust that the user knows what they're doing |
290 | | - dyn_val |
291 | | - else |
292 | | - ft, tt = dyn_val |
293 | | - dyn_src = methodinstance(ft, tt, tls_world_age()) |
294 | | - CompilerJob(dyn_src, job.config) |
295 | | - end |
296 | | - |
297 | | - push!(get!(worklist, dyn_job, LLVM.CallInst[]), call) |
298 | | - end |
299 | | - |
300 | | - # compile and link |
301 | | - for dyn_job in keys(worklist) |
302 | | - # cached compilation |
303 | | - dyn_entry_fn = get!(jobs, dyn_job) do |
304 | | - dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false, |
305 | | - parent_job=job) |
306 | | - dyn_entry_fn = LLVM.name(dyn_meta.entry) |
307 | | - merge!(compiled, dyn_meta.compiled) |
308 | | - @assert context(dyn_ir) == context(ir) |
309 | | - link!(ir, dyn_ir) |
310 | | - changed = true |
311 | | - dyn_entry_fn |
312 | | - end |
313 | | - dyn_entry = functions(ir)[dyn_entry_fn] |
314 | | - |
315 | | - # insert a pointer to the function everywhere the entry is used |
316 | | - T_ptr = convert(LLVMType, Ptr{Cvoid}) |
317 | | - for call in worklist[dyn_job] |
318 | | - @dispose builder=IRBuilder() begin |
319 | | - position!(builder, call) |
320 | | - fptr = if LLVM.version() >= v"17" |
321 | | - T_ptr = LLVM.PointerType() |
322 | | - bitcast!(builder, dyn_entry, T_ptr) |
323 | | - elseif VERSION >= v"1.12.0-DEV.225" |
324 | | - T_ptr = LLVM.PointerType(LLVM.Int8Type()) |
325 | | - bitcast!(builder, dyn_entry, T_ptr) |
326 | | - else |
327 | | - ptrtoint!(builder, dyn_entry, T_ptr) |
328 | | - end |
329 | | - replace_uses!(call, fptr) |
330 | | - end |
331 | | - unsafe_delete!(LLVM.parent(call), call) |
332 | | - end |
333 | | - end |
334 | | - end |
335 | | - |
336 | | - # all deferred compilations should have been resolved |
337 | | - @compiler_assert isempty(uses(dyn_marker)) job |
338 | | - unsafe_delete!(ir, dyn_marker) |
339 | | - end |
340 | 234 |
|
341 | 235 | if libraries |
342 | 236 | # load the runtime outside of a timing block (because it recurses into the compiler) |
@@ -433,8 +327,8 @@ const __llvm_initialized = Ref(false) |
433 | 327 | # finish the module |
434 | 328 | # |
435 | 329 | # we want to finish the module after optimization, so we cannot do so |
436 | | - # during deferred code generation. instead, process the deferred jobs |
437 | | - # here. |
| 330 | + # during deferred code generation. Instead, process the merged module |
| 331 | + # from all the jobs here. |
438 | 332 | if toplevel |
439 | 333 | entry = finish_ir!(job, ir, entry) |
440 | 334 |
|
|
0 commit comments