Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions src/spirv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
append!(errors, check_ir_values(mod, LLVM.DoubleType()))
end

# SPIR-V has no 128-bit integer type. Such IR is not rejected here because a
# legalization pass narrows 128-bit integers to 64 bits; only warn that this is lossy.
i128_uses = check_ir_values(mod, LLVM.IntType(128))
if !isempty(i128_uses)
@safe_warn "Found 128-bit integer operations in SPIR-V kernel; these will be legalized to 64-bit integers, which may cause precision loss or incorrect results for values outside the Int64 range. Use JULIA_DEBUG=GPUCompiler for more details."

end

return errors
end

Expand Down Expand Up @@ -136,6 +144,9 @@ end
# (SPIRV-LLVM-Translator#1140)
rm_freeze!(mod)

# SPIR-V does not support 128-bit integers
legalize_int128!(mod)

# translate to SPIR-V
input = tempname(cleanup=false) * ".bc"
translated = tempname(cleanup=false) * ".spv"
Expand Down Expand Up @@ -283,6 +294,183 @@ function rm_freeze!(mod::LLVM.Module)
return changed
end

# legalize 128-bit integers by narrowing them to 64-bit integers
#
# Every instruction in `mod` that produces or consumes an `i128` is collected and, where
# a rewrite is known, replaced by an equivalent operating on the low 64 bits only. This
# is lossy by design: results are only meaningful for values that fit in 64 bits.
#
# Handled instructions: `load`, `store`, `trunc`, `zext`, `sext`, `add`, `mul`, `icmp`.
# Anything else is reported with a warning and left in place.
#
# In memory, an `i128` is accessed as `{i64, i64}`; element 0 is treated as the low half
# (NOTE(review): this assumes a little-endian layout -- confirm for the target).
#
# Returns `true` if at least one instruction was rewritten.
function legalize_int128!(mod::LLVM.Module)
    job = current_job::CompilerJob
    changed = false
    @tracepoint "legalize int128" begin

    i128 = LLVM.IntType(128)
    i64 = LLVM.IntType(64)

    # in-memory stand-in for i128: {i64, i64}
    i128_replacement = LLVM.StructType([i64, i64])

    for f in functions(mod)
        # collect up front: the rewrite below erases instructions while we go
        worklist = Vector{LLVM.Instruction}()
        for bb in blocks(f), inst in instructions(bb)
            if value_type(inst) == i128 ||
               any(op -> value_type(op) == i128, operands(inst))
                push!(worklist, inst)
            end
        end

        isempty(worklist) && continue
        @safe_debug "Legalizing $(length(worklist)) i128 instruction(s) in function $(LLVM.name(f))"

        @dispose builder = IRBuilder() begin
            # low 64 bits of `v`; operands that an earlier rewrite already narrowed
            # are no longer i128 and are passed through unchanged
            lo64 = v -> value_type(v) == i128 ? trunc!(builder, v, i64) : v

            # pointer to the replacement struct, in the same address space as `ptr`
            # (a bitcast must not change the address space)
            replacement_ptr = ptr -> bitcast!(builder, ptr,
                LLVM.PointerType(i128_replacement, LLVM.addrspace(value_type(ptr))))

            for inst in worklist
                position!(builder, inst)

                @safe_debug "Legalizing i128 instruction: $(string(inst))"

                if inst isa LLVM.LoadInst && value_type(inst) == i128
                    @safe_debug " Converting i128 load to i64 (low bits only)"
                    # load i128 -> load {i64, i64}, keep only the low half
                    new_ptr = replacement_ptr(operands(inst)[1])
                    new_load = load!(builder, i128_replacement, new_ptr)
                    lo = extract_value!(builder, new_load, 0)
                    replace_uses!(inst, lo)
                    erase!(inst)
                    changed = true

                elseif inst isa LLVM.StoreInst
                    val, ptr = operands(inst)
                    # if the value was already narrowed, the store is left as it is
                    if value_type(val) == i128
                        @safe_debug " Converting i128 store to i64 (high bits zeroed)"
                        # store i128 -> store {lo, 0}; the value has to be truncated
                        # first, as the struct fields are i64
                        lo = trunc!(builder, val, i64)
                        struct_val = LLVM.UndefValue(i128_replacement)
                        struct_val = insert_value!(builder, struct_val, lo, 0)
                        struct_val = insert_value!(builder, struct_val,
                                                   LLVM.ConstantInt(i64, 0), 1)
                        store!(builder, struct_val, replacement_ptr(ptr))
                        erase!(inst)
                        changed = true
                    end

                elseif inst isa LLVM.TruncInst
                    src = operands(inst)[1]
                    dest = value_type(inst)
                    if dest == i128
                        @safe_debug " Converting truncation to i128 -> truncation to i64"
                        # truncation from a wider type: truncate to i64 instead
                        new_trunc = value_type(src) == i64 ? src : trunc!(builder, src, i64)
                        replace_uses!(inst, new_trunc)
                        erase!(inst)
                        changed = true
                    elseif value_type(src) != i128
                        # truncation *from* i128 whose source has been narrowed to i64
                        # in the meantime; rebuild it on the narrowed value.
                        # NOTE(review): assumes `dest` is at most 64 bits wide.
                        new_trunc = value_type(src) == dest ? src : trunc!(builder, src, dest)
                        replace_uses!(inst, new_trunc)
                        erase!(inst)
                        changed = true
                    else
                        # source is still i128 (its producer was not legalized)
                        @safe_warn " Unhandled i128 instruction type: $(typeof(inst))"
                    end

                elseif inst isa LLVM.ZExtInst && value_type(inst) == i128
                    @safe_debug " Converting zero extension to i128 -> zero extension to i64"
                    src = operands(inst)[1]
                    new_zext = value_type(src) == i64 ? src : zext!(builder, src, i64)
                    replace_uses!(inst, new_zext)
                    erase!(inst)
                    changed = true

                elseif inst isa LLVM.SExtInst && value_type(inst) == i128
                    @safe_debug " Converting sign extension to i128 -> sign extension to i64"
                    src = operands(inst)[1]
                    new_sext = value_type(src) == i64 ? src : sext!(builder, src, i64)
                    replace_uses!(inst, new_sext)
                    erase!(inst)
                    changed = true

                elseif inst isa LLVM.AddInst && value_type(inst) == i128
                    @safe_debug " Converting i128 addition to i64 (truncating to low 64 bits)"
                    # the low 64 bits of a sum only depend on the low 64 bits of the terms
                    lhs, rhs = operands(inst)
                    sum_lo = add!(builder, lo64(lhs), lo64(rhs))
                    replace_uses!(inst, sum_lo)
                    erase!(inst)
                    changed = true

                elseif inst isa LLVM.MulInst && value_type(inst) == i128
                    @safe_debug " Converting i128 multiplication to i64 (truncating to low 64 bits)"
                    # only yields the low 64 bits of the product
                    lhs, rhs = operands(inst)
                    prod_lo = mul!(builder, lo64(lhs), lo64(rhs))
                    replace_uses!(inst, prod_lo)
                    erase!(inst)
                    changed = true

                elseif inst isa LLVM.ICmpInst
                    lhs, rhs = operands(inst)
                    # if both operands were narrowed already, the comparison is fine as is
                    if value_type(lhs) == i128 || value_type(rhs) == i128
                        pred = LLVM.predicate(inst)
                        @safe_debug " Converting i128 comparison (predicate: $pred) using low 64 bits"
                        # same predicate on the low halves; only equivalent for values
                        # that fit in 64 bits
                        result = icmp!(builder, pred, lo64(lhs), lo64(rhs))
                        replace_uses!(inst, result)
                        erase!(inst)
                        changed = true
                    end

                else
                    @safe_warn " Unhandled i128 instruction type: $(typeof(inst))"
                end
            end
        end
    end

    end
    return changed
end

# wrap byval pointers in a single-value struct
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
ft = function_type(f)::LLVM.FunctionType
Expand Down