Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/analysis/ADAnalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
end

@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
if Compiler.result_is_constabi(interp, result)
return nothing
end
Expand Down
11 changes: 3 additions & 8 deletions src/transform/codegen/dae_factory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end

const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}

function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, settings::Settings)
function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, world::UInt, settings::Settings)
(; result, structure) = state

numstates = zeros(Int, Int(LastEquationStateKind))
Expand Down Expand Up @@ -111,12 +111,7 @@ function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal
resize!(ir.cfg.blocks, 1)
empty!(ir.cfg.blocks[1].succs)
Compiler.verify_ir(ir)

@async @eval Main begin
interface_ir = $ir
end

return Core.OpaqueClosure(ir; slotnames = [:captures, :out, :du, :u, :p, :t])
return optimized_opaque_closure(ir, world; slotnames = [:captures, :out, :du, :u, :p, :t])
end

"""
Expand Down Expand Up @@ -173,7 +168,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio
end

daef_ci = rhs_finish!(state, ci, key, world, settings, 1)
oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, settings)
oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, world, settings)
end

line = result.ir[SSAValue(1)][:line]
Expand Down
2 changes: 1 addition & 1 deletion src/transform/codegen/init_factory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
@insert_instruction_here(oc_compact, line, settings, (return out_arr)::Vector{Float64})

ir_oc = Compiler.finish(oc_compact)
oc = Core.OpaqueClosure(ir_oc)
oc = optimized_opaque_closure(ir_oc, world)

line = result.ir[SSAValue(1)][:line]

Expand Down
2 changes: 1 addition & 1 deletion src/transform/codegen/ode_factory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
interface_ir = Compiler.finish(interface_ic)
maybe_rewrite_debuginfo!(interface_ir, settings)
Compiler.verify_ir(interface_ir)
interface_oc = Core.OpaqueClosure(interface_ir; slotnames = [:self, :du, :u, :p, :t])
interface_oc = optimized_opaque_closure(interface_ir, world; slotnames = [:self, :du, :u, :p, :t])

line = result.ir[SSAValue(1)][:line]

Expand Down
38 changes: 38 additions & 0 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,44 @@ function cache_dae_ci!(old_ci, src, debuginfo, abi, owner; rettype=Tuple)
return daef_ci
end

function optimized_opaque_closure(ir::IRCode, world::UInt; slotnames = nothing)
oc = Core.OpaqueClosure(ir)
adjust_world_bounds!(oc)
optimized_oc = optimize_opaque_closure!(oc, world; slotnames)
adjust_world_bounds!(optimized_oc)
return optimized_oc
end

function optimize_opaque_closure!(oc::Core.OpaqueClosure, world::UInt; slotnames = nothing)
method = oc.source
ci = method.specializations.cache
ir = reinfer_and_inline(ci, world)
return Core.OpaqueClosure(ir; slotnames)
end

# Not sure if/why this is necessary or even correct, but
# otherwise the `CodeInstance` bounds are outdated.
function adjust_world_bounds!(oc::Core.OpaqueClosure)
ci = oc.source.specializations.cache
@atomic ci.min_world = ci.inferred.min_world
@atomic ci.max_world = ci.inferred.max_world
end

function reinfer_and_inline(ci::CodeInstance, world::UInt)
interp = Compiler.NativeInterpreter(world)
mi = Compiler.get_ci_mi(ci)
argtypes = collect(Any, mi.specTypes.parameters)
irsv = Compiler.IRInterpretationState(interp, ci, mi, argtypes, world)
@assert irsv !== nothing
for stmt in irsv.ir.stmts
stmt[:flag] |= Compiler.IR_FLAG_REFINED
end
Compiler.ir_abstract_constant_propagation(interp, irsv)
state = Compiler.InliningState(interp)
ir = Compiler.ssa_inlining_pass!(irsv.ir, state, Compiler.propagate_inbounds(irsv))
return ir
end

function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nospecialize(new_call), settings::Settings, source)
replace_call!(ir, idx, new_call)
settings.insert_stmt_debuginfo || return new_call
Expand Down
Loading