Skip to content

Commit 10b2caa

Browse files
authored
Separated out memory instance rewrite pass (#2561)
* Separated out memory instance rewrite pass * try ci * fix * fix * fix
1 parent 89cb0e7 commit 10b2caa

File tree

5 files changed

+74
-2
lines changed

5 files changed

+74
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
4141
CEnum = "0.4, 0.5"
4242
ChainRulesCore = "1"
4343
EnzymeCore = "0.8.13"
44-
Enzyme_jll = "0.0.195"
44+
Enzyme_jll = "0.0.196"
4545
GPUArraysCore = "0.1.6, 0.2"
4646
GPUCompiler = "1.6"
4747
LLVM = "6.1, 7, 8, 9"

src/absint.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,9 @@ function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType)
255255
end
256256
end
257257

258-
function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool = true, inttoptr::Bool = false)::Tuple{LLVM.Value, Int}
258+
function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool = true, inttoptr::Bool = false, inst::Union{LLVM.Instruction, Nothing} = nothing)::Tuple{LLVM.Value, Int}
259259
offset = 0
260+
pinst = isa(larg, LLVM.Instruction) ? larg::LLVM.Instruction : inst
260261
while true
261262
if isa(larg, LLVM.ConstantExpr)
262263
if opcode(larg) == LLVM.API.LLVMBitCast || opcode(larg) == LLVM.API.LLVMAddrSpaceCast || opcode(larg) == LLVM.API.LLVMPtrToInt
@@ -267,6 +268,24 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
267268
larg = operands(larg)[1]
268269
continue
269270
end
271+
if opcode(larg) == LLVM.API.LLVMGetElementPtr && pinst isa LLVM.Instruction
272+
b = LLVM.IRBuilder()
273+
position!(b, pinst)
274+
offty = LLVM.IntType(8 * sizeof(Int))
275+
offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty)
276+
if isa(offset2, LLVM.ConstantInt)
277+
val = convert(Int, offset2)
278+
if offsetAllowed || val == 0
279+
offset += val
280+
larg = operands(larg)[1]
281+
continue
282+
else
283+
break
284+
end
285+
else
286+
break
287+
end
288+
end
270289
end
271290
if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) || isa(larg, LLVM.IntToPtrInst)
272291
larg = operands(larg)[1]

src/compiler/optimize.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,8 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
488488
gvn!(pm) # Extra
489489
LLVM.run!(pm, mod)
490490
end
491+
492+
rewrite_generic_memory!(mod)
491493

492494
ModulePassManager() do pm
493495
add_library_info!(pm, triple(mod))

src/errors.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,30 @@ function julia_error(
524524
return seen[cur]
525525
end
526526

527+
@static if VERSION < v"1.11-"
528+
else
529+
if isa(cur, LLVM.LoadInst) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(operands(cur)[1])) == Derived
530+
larg, off = get_base_and_offset(operands(cur)[1]; inst=ncur, inttoptr=true)
531+
if isa(larg, LLVM.ConstantInt) && off == sizeof(Int)
532+
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
533+
obj = Base.unsafe_pointer_to_objref(ptr)
534+
if obj isa Memory && obj == typeof(obj).instance
535+
return make_batched(ncur, prevbb)
536+
end
537+
end
538+
end
539+
if isa(cur, LLVM.ConstantExpr) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(cur)) == Derived
540+
larg, off = get_base_and_offset(cur; inst=first(instructions(position(prevbb))), inttoptr=true)
541+
if isa(larg, LLVM.ConstantInt) && (off == sizeof(Int) || off == 0)
542+
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
543+
obj = Base.unsafe_pointer_to_objref(ptr)
544+
if obj isa Memory && obj == typeof(obj).instance
545+
return make_batched(ncur, prevbb)
546+
end
547+
end
548+
end
549+
end
550+
527551
legal, TT, byref = abs_typeof(cur, true)
528552

529553
if legal

src/llvm/transforms.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,6 +2297,33 @@ function checkNoAssumeFalse(mod::LLVM.Module, shouldshow::Bool = false)
22972297
end
22982298
end
22992299

2300+
function rewrite_generic_memory!(mod::LLVM.Module)
2301+
@static if VERSION < v"1.11-"
2302+
else
2303+
for f in functions(mod), bb in blocks(f)
2304+
iter = LLVM.API.LLVMGetFirstInstruction(bb)
2305+
while iter != C_NULL
2306+
inst = LLVM.Instruction(iter)
2307+
iter = LLVM.API.LLVMGetNextInstruction(iter)
2308+
if !isa(inst, LLVM.LoadInst)
2309+
continue
2310+
end
2311+
2312+
if isa(operands(inst)[1], LLVM.ConstantExpr)
2313+
legal2, obj = absint(inst)
2314+
if legal2 && obj isa Memory && obj == typeof(obj).instance
2315+
b = LLVM.IRBuilder()
2316+
position!(b, inst)
2317+
replace_uses!(inst, unsafe_to_llvm(b, obj))
2318+
LLVM.API.LLVMInstructionEraseFromParent(inst)
2319+
continue
2320+
end
2321+
end
2322+
end
2323+
end
2324+
end
2325+
end
2326+
23002327
function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
23012328
# We need to run globalopt first. This is because remove dead args will otherwise
23022329
# take internal functions and replace their args with undef. Then on LLVM up to

0 commit comments

Comments
 (0)