2 changes: 1 addition & 1 deletion src/compiler.jl
@@ -4410,7 +4410,7 @@ function lower_convention(
println(io, string(mod))
println(
io,
- LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction),
+ LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction),
)
println(io, string(wrapper_f))
println(io, "Broken function")
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -2,6 +2,7 @@
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
@@ -12,6 +13,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
+ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
93 changes: 93 additions & 0 deletions test/embedded_bitcode.jl
Member

I seem to remember that LLVM IR isn't very portable across major versions. I presume that would be a potential problem here when testing against different Julia (and LLVM) versions?

Contributor Author

I agree with you, and I also didn't feel this was the best test case.

The only way I could reproduce this error was with ccall inside a wrapper function. By contrast, if I use llvmcall, I get the gradient correctly. When I inspected the LLVM IR generated when passing a view of an array, I saw that Julia inlines the wrapper function with weakly typed SubArray pointers, hence the type mismatch between the call site and the inner ccall.

Also, have a look at this comment in the original issue.
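
For readers unfamiliar with the alternative mentioned above, here is a minimal sketch of calling a snippet of LLVM IR inline through `Base.llvmcall`, with no shared library and no `ccall`. The function name `double_llvmcall` and the IR body are hypothetical illustrations, not code from this PR, and the sketch does not attempt to reproduce the SubArray failure.

```julia
# Hypothetical example (not part of this PR): run a tiny piece of LLVM IR
# inline with Base.llvmcall instead of compiling it into a shared library.
function double_llvmcall(x::Float64)
    # In the string form of llvmcall the arguments are the unnamed values
    # %0, %1, ...; the implicit entry block takes the next number, so with
    # one argument the first temporary is %2.
    return Base.llvmcall(
        """
        %2 = fmul double %0, 2.000000e+00
        ret double %2
        """,
        Float64, Tuple{Float64}, x,
    )
end

println(double_llvmcall(1.5))
```

Because the IR is spliced directly into the calling Julia function, there is no separately compiled callee whose signature has to match a `ccall` site, which may be why this path does not trigger the error discussed here.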

@@ -0,0 +1,93 @@
using Enzyme
using Clang_jll
using Libdl
using Test

const FUNC_LLVM_IR = """
declare double @llvm.rint.f64(double) #1

define i32 @func(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.t, i8* nocapture readnone %arg.arr.0, i8* nocapture readnone %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) local_unnamed_addr #0 {
common.ret:
%.27 = fdiv double %arg.t, 1.000000e-02
%.28 = tail call double @llvm.rint.f64(double %.27)
%.29 = fptosi double %.28 to i64
%.42 = icmp slt i64 %.29, 0
%.43 = select i1 %.42, i64 %arg.arr.5.0, i64 0
%.44 = add i64 %.43, %.29
%.55 = mul i64 %.44, %arg.arr.6.0
%.56 = ptrtoint double* %arg.arr.4 to i64
%.57 = add i64 %.55, %.56
%.58 = inttoptr i64 %.57 to double*
%.59 = load double, double* %.58, align 8
store double %.59, double* %retptr, align 8
ret i32 0
}

define double @func_wrap({ i8*, i32, i8*, i8*, i32 }** %excinfo, double %arg.t, i8* %arg.arr.0, i8* %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) {
entry:
%tmp = alloca double, align 8
%st = call i32 @func(double* %tmp, { i8*, i32, i8*, i8*, i32 }** %excinfo, double %arg.t, i8* %arg.arr.0, i8* %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0)
%val = load double, double* %tmp, align 8
ret double %val
}


attributes #0 = { mustprogress nofree nosync nounwind willreturn }
attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn }
attributes #2 = { noinline }
"""


tmp_dir = tempdir()
tmp_so_file = joinpath(tmp_dir, "func.so")
run(
    pipeline(
Member

Is this even necessary here?

If you're starting from LLVM IR anyway, why not just use llvmcall?

Contributor Author

@ymardoukhi Oct 25, 2025

Surprisingly, I don't hit that execution path when using llvmcall. Have a look at this comment:
#2664 (comment)

The wrapper function around llvmcall handles the SubArray input with no issues, whereas the same wrapper function around ccall will fail to compute the gradient. I stopped short of investigating the core issue; the only clue I got from the emitted LLVM IR is that Julia inlines the wrapper function, which leads to a mismatch between the types at the callsite.

        `$(clang()) -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o $(tmp_so_file)`;
        stdin=IOBuffer(FUNC_LLVM_IR)
    )
)

lib = Libdl.dlopen(tmp_so_file)
const fptr = Libdl.dlsym(lib, :func_wrap)


function func_ccall(t::Float64, arr::AbstractVector{Float64})
    nitems = length(arr)
    bitsize = Base.elsize(arr)
    GC.@preserve arr begin
        excinfo = Ptr{Ptr{Cvoid}}(C_NULL)
        base::Ptr{Cdouble} = pointer(arr)

        ccall(fptr, Cdouble,
            (Ptr{Ptr{Cvoid}}, Cdouble, Ptr{Cvoid}, Ptr{Cvoid},
                Clong, Clong, Ptr{Cdouble}, Clong, Clong),
            excinfo, t, C_NULL, C_NULL, nitems, bitsize,
            base, nitems, nitems * bitsize)
    end
end

@testset "Broken Function ccall + @view" begin
    a = rand(10)
    expected_grad_a = (nothing, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    grad_a = gradient(Reverse, func_ccall, Const(0.0), a)
    @test expected_grad_a == grad_a


    errstream = joinpath(tempdir(), "stdout.txt")
    err_llvmir = nothing
    b = @view a[1:5]

    redirect_stdio(stdout=errstream, stderr=errstream, stdin=devnull) do
        try
            gradient(Reverse, func_ccall, Const(0.0), b)
        catch e
            err_llvmir = e
        end

        @test err_llvmir !== nothing
        @test occursin("Broken function", err_llvmir.info)
    end

    errtxt = read(errstream, String)
    @test occursin("Called function is not the same type as the call!", errtxt)
end