Skip to content

Commit 2bd31a0

Browse files
authored
Refactor inlining to allow re-use in more sophisticated inlining passes (#37027)
The inlining transform basically has three parts: 1. Analysis (What needs to be inlined and are we allowed to do that?) 2. Policy (Should we inline this?) 3. Mechanism (Stuff the bits from one function into the other) At the moment, we already separate this out into two passes: Analysis/Policy (assemble_inline_todo!) and Mechanism (batch_inline!). For our needs in base, the policy bits are quite simple (how large is the optimized version of this function), but that policy is insufficient for some more sophisticated inlining needs I have in an external compiler pass (where I want to interleave inlining with different transforms as well as potentially run inlining multiple times). To facilitate such use cases, this commit optionally splits out the policy part, but lets the analysis and mechanism parts be re-used by a more sophisticated inlining pass. It also refactors the optimization state to more clearly delineate the different independent parts (edge tracking, inference catches, method table), as well as making the different parts optional (where not required). We were already essentially supporting optimization without edge tracking (for testing purposes), so this is just a bit more explicit about it (which is useful for me, since the different inlining passes in my pipeline may need different settings). For base itself, nothing should functionally change, though hopefully things are factored a bit cleaner.
1 parent 840e2fc commit 2bd31a0

File tree

7 files changed

+283
-226
lines changed

7 files changed

+283
-226
lines changed

base/compiler/optimize.jl

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,62 @@
44
# OptimizationState #
55
#####################
66

7-
mutable struct OptimizationState
7+
struct EdgeTracker
8+
edges::Vector{Any}
9+
valid_worlds::RefValue{WorldRange}
10+
EdgeTracker(edges::Vector{Any}, range::WorldRange) =
11+
new(edges, RefValue{WorldRange}(range))
12+
end
13+
EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt))
14+
15+
intersect!(et::EdgeTracker, range::WorldRange) =
16+
et.valid_worlds[] = intersect(et.valid_worlds[], range)
17+
18+
push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
19+
function push!(et::EdgeTracker, ci::CodeInstance)
20+
intersect!(et, WorldRange(min_world(li), max_world(li)))
21+
push!(et, ci.def)
22+
end
23+
24+
struct InferenceCaches{T, S}
25+
inf_cache::T
26+
mi_cache::S
27+
end
28+
29+
struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCaches, Nothing}, V <: Union{Nothing, MethodTableView}}
830
params::OptimizationParams
31+
et::S
32+
caches::T
33+
method_table::V
34+
end
35+
36+
mutable struct OptimizationState
937
linfo::MethodInstance
10-
calledges::Vector{Any}
1138
src::CodeInfo
1239
stmt_info::Vector{Any}
1340
mod::Module
1441
nargs::Int
15-
world::UInt
16-
valid_worlds::WorldRange
1742
sptypes::Vector{Any} # static parameters
1843
slottypes::Vector{Any}
1944
const_api::Bool
20-
# TODO: This will be eliminated once optimization no longer needs to do method lookups
21-
interp::AbstractInterpreter
45+
inlining::InliningState
2246
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
2347
s_edges = frame.stmt_edges[1]
2448
if s_edges === nothing
2549
s_edges = []
2650
frame.stmt_edges[1] = s_edges
2751
end
2852
src = frame.src
29-
return new(params, frame.linfo,
30-
s_edges::Vector{Any},
53+
inlining = InliningState(params,
54+
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
55+
InferenceCaches(
56+
get_inference_cache(interp),
57+
WorldView(code_cache(interp), frame.world)),
58+
method_table(interp))
59+
return new(frame.linfo,
3160
src, frame.stmt_info, frame.mod, frame.nargs,
32-
frame.world, frame.valid_worlds,
3361
frame.sptypes, frame.slottypes, false,
34-
interp)
62+
inlining)
3563
end
3664
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
3765
# prepare src for running optimization passes
@@ -45,7 +73,6 @@ mutable struct OptimizationState
4573
if slottypes === nothing
4674
slottypes = Any[ Any for i = 1:nslots ]
4775
end
48-
s_edges = []
4976
stmt_info = Any[nothing for i = 1:nssavalues]
5077
# cache some useful state computations
5178
toplevel = !isa(linfo.def, Method)
@@ -57,12 +84,18 @@ mutable struct OptimizationState
5784
inmodule = linfo.def::Module
5885
nargs = 0
5986
end
60-
return new(params, linfo,
61-
s_edges::Vector{Any},
87+
# Allow using the global MI cache, but don't track edges.
88+
# This method is mostly used for unit testing the optimizer
89+
inlining = InliningState(params,
90+
nothing,
91+
InferenceCaches(
92+
get_inference_cache(interp),
93+
WorldView(code_cache(interp), get_world_counter())),
94+
method_table(interp))
95+
return new(linfo,
6296
src, stmt_info, inmodule, nargs,
63-
get_world_counter(), WorldRange(UInt(1), get_world_counter()),
6497
sptypes_from_meth_instance(linfo), slottypes, false,
65-
interp)
98+
inlining)
6699
end
67100
end
68101

@@ -106,25 +139,6 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)
106139

107140
_topmod(sv::OptimizationState) = _topmod(sv.mod)
108141

109-
function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange)
110-
sv.valid_worlds = intersect(sv.valid_worlds, valid_worlds)
111-
@assert(sv.world in sv.valid_worlds, "invalid age range update")
112-
nothing
113-
end
114-
115-
function add_backedge!(li::MethodInstance, caller::OptimizationState)
116-
#TODO: deprecate this?
117-
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
118-
push!(caller.calledges, li)
119-
nothing
120-
end
121-
122-
function add_backedge!(li::CodeInstance, caller::OptimizationState)
123-
update_valid_age!(caller, WorldRange(min_world(li), max_world(li)))
124-
add_backedge!(li.def, caller)
125-
nothing
126-
end
127-
128142
function isinlineable(m::Method, me::OptimizationState, params::OptimizationParams, union_penalties::Bool, bonus::Int=0)
129143
# compute the cost (size) of inlining this code
130144
inlineable = false

base/compiler/ssair/driver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
124124
#@Base.show ("after_construct", ir)
125125
# TODO: Domsorting can produce an updated domtree - no need to recompute here
126126
@timeit "compact 1" ir = compact!(ir)
127-
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv)
127+
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
128128
#@timeit "verify 2" verify_ir(ir)
129129
ir = compact!(ir)
130130
#@Base.show ("before_sroa", ir)

0 commit comments

Comments
 (0)