@@ -482,8 +482,7 @@ function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecial
482482 nested_codegen! (mode, mod, funcspec, world)
483483end
484484
485- function prepare_llvm (mod:: LLVM.Module , job, meta)
486- interp = GPUCompiler. get_interpreter (job)
485+ function prepare_llvm (interp, mod:: LLVM.Module , job, meta)
487486 for f in functions (mod)
488487 attributes = function_attributes (f)
489488 push! (attributes, StringAttribute (" enzymejl_world" , string (job. world)))
@@ -1001,7 +1000,7 @@ end
10011000 return
10021001end
10031002
1004- function set_module_types! (mod:: LLVM.Module , primalf:: Union{Nothing, LLVM.Function} , job, edges, run_enzyme, mode:: API.CDerivativeMode )
1003+ function set_module_types! (interp, mod:: LLVM.Module , primalf:: Union{Nothing, LLVM.Function} , job, edges, run_enzyme, mode:: API.CDerivativeMode )
10051004
10061005 for f in functions (mod)
10071006 mi, RT = enzyme_custom_extract_mi (f, false )
@@ -1141,7 +1140,6 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11411140 custom = Dict {String,LLVM.API.LLVMLinkage} ()
11421141
11431142 world = job. world
1144- interp = GPUCompiler. get_interpreter (job)
11451143 method_table = Core. Compiler. method_table (interp)
11461144
11471145 state = HandlerState (
@@ -1203,7 +1201,8 @@ function nested_codegen!(
12031201 GPUCompiler. prepare_job! (job)
12041202 otherMod, meta = GPUCompiler. emit_llvm (job)
12051203
1206- prepare_llvm (otherMod, job, meta)
1204+ interp = GPUCompiler. get_interpreter (job)
1205+ prepare_llvm (interp, otherMod, job, meta)
12071206
12081207 entry = name (meta. entry)
12091208
@@ -1221,12 +1220,12 @@ function nested_codegen!(
12211220 LLVM. run! (pm, otherMod)
12221221 end
12231222
1224- check_ir (job, otherMod)
1223+ check_ir (interp, job, otherMod)
12251224
12261225 # Skipped inline of blas
12271226
12281227 run_enzyme = false
1229- set_module_types! (otherMod, nothing , job, edges, run_enzyme, mode)
1228+ set_module_types! (interp, otherMod, nothing , job, edges, run_enzyme, mode)
12301229
12311230 # Apply first stage of optimization's so that this module is at the same stage as `mod`
12321231 optimize! (otherMod, JIT. get_tm ())
@@ -2233,6 +2232,7 @@ const DumpPostWrap = Ref(false)
22332232
22342233function enzyme! (
22352234 job:: CompilerJob ,
2235+ interp,
22362236 mod:: LLVM.Module ,
22372237 primalf:: LLVM.Function ,
22382238 @nospecialize (TT:: Type ),
@@ -2251,7 +2251,6 @@ function enzyme!(
22512251 API. EnzymeDumpModuleRef (mod. ref)
22522252 end
22532253 world = job. world
2254- interp = GPUCompiler. get_interpreter (job)
22552254 rt = job. config. params. rt
22562255 runtimeActivity = job. config. params. runtimeActivity
22572256 strongZero = job. config. params. strongZero
@@ -4290,14 +4289,14 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
42904289 # ??? entry_abi
42914290 )
42924291 primal_job = CompilerJob (primal, primal_config, job. world)
4293-
42944292 @safe_debug " Emit LLVM with" primal_job
42954293 GPUCompiler. prepare_job! (primal_job)
42964294 mod, meta = GPUCompiler. emit_llvm (primal_job)
42974295 edges = Any[]
42984296 mod_to_edges[mod] = edges
42994297
4300- prepare_llvm (mod, primal_job, meta)
4298+ primal_interp = GPUCompiler. get_interpreter (primal_job)
4299+ prepare_llvm (primal_interp, mod, primal_job, meta)
43014300 for f in functions (mod)
43024301 permit_inlining! (f)
43034302 end
@@ -4311,7 +4310,8 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
43114310 if DumpPreCheck[]
43124311 API. EnzymeDumpModuleRef (mod. ref)
43134312 end
4314- check_ir (job, mod)
4313+ interp = GPUCompiler. get_interpreter (job)
4314+ check_ir (interp, job, mod)
43154315
43164316 disableFallback = String[]
43174317
@@ -4420,7 +4420,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
44204420 end
44214421 end
44224422
4423- custom, state = set_module_types! (mod, primalf, job, edges, params. run_enzyme, mode)
4423+ custom, state = set_module_types! (interp, mod, primalf, job, edges, params. run_enzyme, mode)
44244424
44254425 primalf = state. primalf
44264426 must_wrap = state. must_wrap
@@ -4899,6 +4899,7 @@ end
48994899
49004900 adjointf, augmented_primalf, TapeType = enzyme! (
49014901 job,
4902+ interp,
49024903 mod,
49034904 primalf,
49044905 TT,
0 commit comments