Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -28,6 +30,9 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[weakdeps]
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[sources]
GPUCompiler = {rev = "tb/kernel_state_reference", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}

[extensions]
SpecialFunctionsExt = "SpecialFunctions"

Expand All @@ -49,6 +54,8 @@ PrecompileTools = "1"
Preferences = "1"
Printf = "1"
Random = "1"
Random123 = "1.7.1"
RandomNumbers = "1.6.0"
SHA = "0.7"
ScopedValues = "1.3.0"
SpecialFunctions = "2"
Expand Down
1 change: 1 addition & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include("device/intrinsics/synchronization.jl")
include("device/intrinsics/memory.jl")
include("device/intrinsics/simd.jl")
include("device/intrinsics/atomics.jl")
include("device/intrinsics/random.jl")
include("device/quirks.jl")

# array essentials
Expand Down
70 changes: 70 additions & 0 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,76 @@ GPUCompiler.runtime_module(::MetalCompilerJob) = Metal

GPUCompiler.method_table(::MetalCompilerJob) = method_table

# The per-kernel state type for Metal compilation jobs: GPUCompiler threads a
# `KernelState` value (see `execution.jl`, where it is constructed with a seed)
# through to every launched kernel as a hidden argument.
function GPUCompiler.kernel_state_type(job::MetalCompilerJob)
    return KernelState
end

function GPUCompiler.finish_module!(@nospecialize(job::MetalCompilerJob),
                                    mod::LLVM.Module, entry::LLVM.Function)
    # Finalize a compiled Metal module: first run the generic GPUCompiler
    # finalization, then — if this kernel uses Metal.jl's device-side RNG,
    # detected via the presence of the "global_random_keys" global — inject a
    # call that initializes the shared RNG state before any user code runs.
    entry = invoke(GPUCompiler.finish_module!,
                   Tuple{CompilerJob{MetalCompilerTarget}, LLVM.Module, LLVM.Function},
                   job, mod, entry)

    # if this kernel uses our RNG, we should prime the shared state.
    # XXX: these transformations should really happen at the Julia IR level...
    if haskey(globals(mod), "global_random_keys")
        f = initialize_rng_state
        ft = typeof(f)
        tt = Tuple{}

        # don't recurse into `initialize_rng_state()` itself
        # (when compiling the deferred job below, this same hook fires again;
        # bail out early to avoid injecting the call into its own initializer)
        if job.source.specTypes.parameters[1] == ft
            return entry
        end

        # create a deferred compilation job for `initialize_rng_state()`
        src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
        job = CompilerJob(src, cfg, job.world)
        # register the job in GPUCompiler's global registry; the numeric id is
        # resolved back to this job when the `deferred_codegen` marker call is
        # lowered later in the pipeline.
        # NOTE(review): `length(...) + 1` assumes ids are dense and only ever
        # appended — confirm this matches how GPUCompiler assigns ids elsewhere
        # (a concurrent compilation could otherwise race on the same id).
        id = length(GPUCompiler.deferred_codegen_jobs) + 1
        GPUCompiler.deferred_codegen_jobs[id] = job

        # generate IR for calls to `deferred_codegen` and the resulting function pointer:
        # prepend a fresh basic block so the initialization runs before any
        # existing kernel code, then branch into the original entry block.
        top_bb = first(blocks(entry))
        bb = BasicBlock(top_bb, "initialize_rng")
        @dispose builder=IRBuilder() begin
            position!(builder, bb)
            # give the inserted instructions a debug location so they don't
            # invalidate the function's debug info
            subprogram = LLVM.subprogram(entry)
            if subprogram !== nothing
                loc = DILocation(0, 0, subprogram)
                debuglocation!(builder, loc)
            end
            # NOTE(review): this unconditionally replaces the DILocation set
            # just above whenever a subprogram exists — confirm whether both
            # calls are intended, or whether this was meant as a fallback for
            # the `subprogram === nothing` case only.
            debuglocation!(builder, first(instructions(top_bb)))

            # call the `deferred_codegen` marker function
            # (the marker's return type differs by toolchain: opaque pointers
            # on LLVM 17+, a typed i8* on Julia 1.12+, and a plain i64 that is
            # later inttoptr'd on older versions)
            T_ptr = if LLVM.version() >= v"17"
                LLVM.PointerType()
            elseif VERSION >= v"1.12.0-DEV.225"
                LLVM.PointerType(LLVM.Int8Type())
            else
                LLVM.Int64Type()
            end
            T_id = convert(LLVMType, Int)
            deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
            # reuse the marker declaration if an earlier transformation
            # already added it to this module
            deferred_codegen = if haskey(functions(mod), "deferred_codegen")
                functions(mod)["deferred_codegen"]
            else
                LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
            end
            fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

            # call the `initialize_rng_state` function
            # (cast the marker's result to a pointer of the initializer's
            # actual function type, then invoke it)
            rt = Core.Compiler.return_type(f, tt)
            llvm_rt = convert(LLVMType, rt)
            llvm_ft = LLVM.FunctionType(llvm_rt)
            fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
            call!(builder, llvm_ft, fptr)
            # fall through into the original entry block
            br!(builder, top_bb)
        end

        # XXX: put some of the above behind GPUCompiler abstractions
        # (e.g., a compile-time version of `deferred_codegen`)
    end
    return entry
end

function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
mod::LLVM.Module, entry::LLVM.Function)
Expand Down
5 changes: 3 additions & 2 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,10 @@ end
cmdbuf = MTLCommandBuffer(queue)
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
cce = MTLComputeCommandEncoder(cmdbuf)
kernel_state = MtlRefValue(KernelState(make_seed(kernel)))
argument_buffers = try
MTL.set_function!(cce, kernel.pipeline)
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
bufs = encode_arguments!(cce, kernel, kernel.f, kernel_state, args...)
MTL.append_current_function!(cce, groups, threads)
bufs
finally
Expand All @@ -295,7 +296,7 @@ end
# kernel has actually completed.
#
# TODO: is there a way to bind additional resources to the command buffer?
roots = [kernel.f, args]
roots = [kernel.f, kernel_state, args]
MTL.on_completed(cmdbuf) do buf
empty!(roots)
foreach(free, argument_buffers)
Expand Down
Loading