Skip to content

Commit 3a0e855

Browse files
committed
wip
1 parent 07b0634 commit 3a0e855

File tree

5 files changed

+22
-3
lines changed

5 files changed

+22
-3
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
3030
[weakdeps]
3131
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3232

33+
[sources]
34+
GPUCompiler = {rev = "tb/kernel_state_reference", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
35+
3336
[extensions]
3437
SpecialFunctionsExt = "SpecialFunctions"
3538

src/compiler/compilation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ GPUCompiler.runtime_module(::MetalCompilerJob) = Metal
88

99
GPUCompiler.method_table(::MetalCompilerJob) = method_table
1010

11-
GPUCompiler.kernel_state_type(job::MetalCompilerJob) = MtlRefValue{KernelState}
11+
GPUCompiler.kernel_state_type(job::MetalCompilerJob) = KernelState
1212

1313
function GPUCompiler.finish_module!(@nospecialize(job::MetalCompilerJob),
1414
mod::LLVM.Module, entry::LLVM.Function)

src/device/runtime.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,4 @@ struct KernelState
3939
random_seed::UInt32
4040
end
4141

42-
#@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)
43-
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(MtlRefValue{KernelState})[]
42+
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
55
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
66
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
77
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
8+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1011
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -20,3 +21,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2021
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2122
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2223
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
24+
25+
[sources]
26+
GPUCompiler = {rev = "tb/kernel_state_reference", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}

test/device/random.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Device-side RNG smoke test: every thread writes one `rand(Float32)` sample into its
# own slot of a device array, and the host then validates the samples.
@testset "device-side rng" begin
    # Kernel: store a single random Float32 at this thread's 1-D grid position.
    function rand_kernel!(a)
        i = thread_position_in_grid_1d()
        a[i] = rand(Float32)
        return
    end

    n = 128
    # `Metal.fill(value, dims...)` allocates a new device array. `fill!` expects an
    # existing array as its first argument, so `Metal.fill!(-1f0, n)` has no matching
    # method. The -1 sentinel is outside [0, 1), so a slot the kernel never wrote
    # fails the range check below.
    a = Metal.fill(-1f0, n)
    @metal threads=n rand_kernel!(a)

    # Copy the results back once and run both checks on the host copy.
    host = Array(a)
    @test all(x -> 0 <= x < 1, host)
    # Each thread is expected to have produced a distinct sample.
    @test length(unique(host)) == n
end

0 commit comments

Comments
 (0)