Skip to content

Commit 701f5f9

Browse files
authored
Fix compile time bug (#2556)
1 parent 3a6776f commit 701f5f9

File tree

3 files changed

+35
-23
lines changed

3 files changed

+35
-23
lines changed

src/compiler.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,7 @@ function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecial
482482
nested_codegen!(mode, mod, funcspec, world)
483483
end
484484

485-
function prepare_llvm(mod::LLVM.Module, job, meta)
486-
interp = GPUCompiler.get_interpreter(job)
485+
function prepare_llvm(interp, mod::LLVM.Module, job, meta)
487486
for f in functions(mod)
488487
attributes = function_attributes(f)
489488
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
@@ -1001,7 +1000,7 @@ end
10011000
return
10021001
end
10031002

1004-
function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)
1003+
function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)
10051004

10061005
for f in functions(mod)
10071006
mi, RT = enzyme_custom_extract_mi(f, false)
@@ -1141,7 +1140,6 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11411140
custom = Dict{String,LLVM.API.LLVMLinkage}()
11421141

11431142
world = job.world
1144-
interp = GPUCompiler.get_interpreter(job)
11451143
method_table = Core.Compiler.method_table(interp)
11461144

11471145
state = HandlerState(
@@ -1203,7 +1201,8 @@ function nested_codegen!(
12031201
GPUCompiler.prepare_job!(job)
12041202
otherMod, meta = GPUCompiler.emit_llvm(job)
12051203

1206-
prepare_llvm(otherMod, job, meta)
1204+
interp = GPUCompiler.get_interpreter(job)
1205+
prepare_llvm(interp, otherMod, job, meta)
12071206

12081207
entry = name(meta.entry)
12091208

@@ -1221,12 +1220,12 @@ function nested_codegen!(
12211220
LLVM.run!(pm, otherMod)
12221221
end
12231222

1224-
check_ir(job, otherMod)
1223+
check_ir(interp, job, otherMod)
12251224

12261225
# Skipped inline of blas
12271226

12281227
run_enzyme = false
1229-
set_module_types!(otherMod, nothing, job, edges, run_enzyme, mode)
1228+
set_module_types!(interp, otherMod, nothing, job, edges, run_enzyme, mode)
12301229

12311230
# Apply first stage of optimization's so that this module is at the same stage as `mod`
12321231
optimize!(otherMod, JIT.get_tm())
@@ -2233,6 +2232,7 @@ const DumpPostWrap = Ref(false)
22332232

22342233
function enzyme!(
22352234
job::CompilerJob,
2235+
interp,
22362236
mod::LLVM.Module,
22372237
primalf::LLVM.Function,
22382238
@nospecialize(TT::Type),
@@ -2251,7 +2251,6 @@ function enzyme!(
22512251
API.EnzymeDumpModuleRef(mod.ref)
22522252
end
22532253
world = job.world
2254-
interp = GPUCompiler.get_interpreter(job)
22552254
rt = job.config.params.rt
22562255
runtimeActivity = job.config.params.runtimeActivity
22572256
strongZero = job.config.params.strongZero
@@ -4290,14 +4289,14 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
42904289
# ??? entry_abi
42914290
)
42924291
primal_job = CompilerJob(primal, primal_config, job.world)
4293-
42944292
@safe_debug "Emit LLVM with" primal_job
42954293
GPUCompiler.prepare_job!(primal_job)
42964294
mod, meta = GPUCompiler.emit_llvm(primal_job)
42974295
edges = Any[]
42984296
mod_to_edges[mod] = edges
42994297

4300-
prepare_llvm(mod, primal_job, meta)
4298+
primal_interp = GPUCompiler.get_interpreter(primal_job)
4299+
prepare_llvm(primal_interp, mod, primal_job, meta)
43014300
for f in functions(mod)
43024301
permit_inlining!(f)
43034302
end
@@ -4311,7 +4310,8 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
43114310
if DumpPreCheck[]
43124311
API.EnzymeDumpModuleRef(mod.ref)
43134312
end
4314-
check_ir(job, mod)
4313+
interp = GPUCompiler.get_interpreter(job)
4314+
check_ir(interp, job, mod)
43154315

43164316
disableFallback = String[]
43174317

@@ -4420,7 +4420,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
44204420
end
44214421
end
44224422

4423-
custom, state = set_module_types!(mod, primalf, job, edges, params.run_enzyme, mode)
4423+
custom, state = set_module_types!(interp, mod, primalf, job, edges, params.run_enzyme, mode)
44244424

44254425
primalf = state.primalf
44264426
must_wrap = state.must_wrap
@@ -4899,6 +4899,7 @@ end
48994899

49004900
adjointf, augmented_primalf, TapeType = enzyme!(
49014901
job,
4902+
interp,
49024903
mod,
49034904
primalf,
49044905
TT,

src/compiler/interpreter.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,26 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
138138
handler::T
139139
end
140140

141-
141+
const SigCache = Dict{Tuple, Dict{UInt, Base.IdSet{Type}}}()
142142
function get_rule_signatures(f, TT, world)
143+
subdict = if haskey(SigCache, (f, TT))
144+
SigCache[(f, TT)]
145+
else
146+
tmp = Dict{UInt, Base.IdSet{Type}}()
147+
SigCache[(f, TT)] = tmp
148+
tmp
149+
end
150+
if haskey(subdict, world)
151+
return subdict[world]
152+
end
143153
fwdrules_meths = Base._methods(f, TT, -1, world)::Vector
144154
sigs = Type[]
145155
for rule in fwdrules_meths
146156
push!(sigs, (rule::Core.MethodMatch).method.sig)
147157
end
148-
return Base.IdSet{Type}(sigs)
158+
result = Base.IdSet{Type}(sigs)
159+
subdict[world] = result
160+
return result
149161
end
150162

151163
function rule_sigs_equal(a, b)

src/compiler/validation.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,15 @@ function restore_lookups(mod::LLVM.Module)::Nothing
177177
end
178178
end
179179

180-
function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module)
181-
errors = check_ir!(job, IRError[], mod)
180+
function check_ir(interp, @nospecialize(job::CompilerJob), mod::LLVM.Module)
181+
errors = check_ir!(interp, job, IRError[], mod)
182182
unique!(errors)
183183
if !isempty(errors)
184184
throw(InvalidIRError(job, errors))
185185
end
186186
end
187187

188-
function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module)
188+
function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module)
189189
imported = Set(String[])
190190
if haskey(functions(mod), "malloc")
191191
f = functions(mod)["malloc"]
@@ -209,7 +209,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod
209209
if in(f, del)
210210
continue
211211
end
212-
check_ir!(job, errors, imported, f, del, mod)
212+
check_ir!(interp, job, errors, imported, f, del, mod)
213213
end
214214
for d in del
215215
LLVM.API.LLVMDeleteFunction(d)
@@ -220,7 +220,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod
220220
if in(f, del)
221221
continue
222222
end
223-
check_ir!(job, errors, imported, f, del, mod)
223+
check_ir!(interp, job, errors, imported, f, del, mod)
224224
end
225225
for d in del
226226
LLVM.API.LLVMDeleteFunction(d)
@@ -229,7 +229,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod
229229
return errors
230230
end
231231

232-
function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module)
232+
function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module)
233233
calls = LLVM.CallInst[]
234234
isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0
235235
mod = LLVM.parent(f)
@@ -466,7 +466,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
466466

467467
while length(calls) > 0
468468
inst = pop!(calls)
469-
check_ir!(job, errors, imported, inst, calls, mod)
469+
check_ir!(interp, job, errors, imported, inst, calls, mod)
470470
end
471471
return errors
472472
end
@@ -610,9 +610,8 @@ end
610610
import GPUCompiler:
611611
DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION
612612
import GPUCompiler: backtrace, isintrinsic
613-
function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module)
613+
function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module)
614614
world = job.world
615-
interp = GPUCompiler.get_interpreter(job)
616615
method_table = Core.Compiler.method_table(interp)
617616
bt = backtrace(inst)
618617
dest = called_operand(inst)

0 commit comments

Comments
 (0)