Skip to content

Commit a26e0df

Browse files
Firestar99eddyb
authored andcommitted
compiletest: shared memory reductions, using the same type in buffers and as shared memory
1 parent 13d851d commit a26e0df

File tree

4 files changed

+242
-0
lines changed

4 files changed

+242
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// build-pass
2+
3+
use core::ops::{Add, AddAssign, Deref, DerefMut};
4+
use spirv_std::arch::workgroup_memory_barrier_with_group_sync;
5+
use spirv_std::glam::*;
6+
use spirv_std::spirv;
7+
8+
#[derive(Copy, Clone, Debug)]
9+
pub struct Value(pub [f32; 4]);
10+
11+
impl Deref for Value {
12+
type Target = [f32; 4];
13+
14+
fn deref(&self) -> &Self::Target {
15+
&self.0
16+
}
17+
}
18+
19+
impl DerefMut for Value {
20+
fn deref_mut(&mut self) -> &mut Self::Target {
21+
&mut self.0
22+
}
23+
}
24+
25+
impl Add for Value {
26+
type Output = Self;
27+
28+
fn add(self, rhs: Self) -> Self::Output {
29+
Self([
30+
self[0] + rhs[0],
31+
self[1] + rhs[1],
32+
self[2] + rhs[2],
33+
self[3] + rhs[3],
34+
])
35+
}
36+
}
37+
38+
impl AddAssign for Value {
39+
fn add_assign(&mut self, rhs: Self) {
40+
*self = *self + rhs;
41+
}
42+
}
43+
44+
pub const WG_SIZE_SHIFT: usize = 5;
45+
pub const WG_SIZE: usize = 1 << WG_SIZE_SHIFT;
46+
47+
// threads must be a literal, constants don't work
48+
const _: () = {
49+
assert!(WG_SIZE == 32);
50+
};
51+
#[spirv(compute(threads(32)))]
52+
pub fn main(
53+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
54+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
55+
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
56+
#[spirv(local_invocation_index)] inv_id: UVec3,
57+
) {
58+
unsafe {
59+
let inv_id = inv_id.x as usize;
60+
shared[inv_id] = input[inv_id];
61+
workgroup_memory_barrier_with_group_sync();
62+
63+
let mut mask = WG_SIZE << 1;
64+
while mask != 0 {
65+
if inv_id < mask {
66+
shared[inv_id] += shared[inv_id + mask];
67+
}
68+
workgroup_memory_barrier_with_group_sync();
69+
mask <<= 1;
70+
}
71+
72+
if inv_id == 0 {
73+
*output = shared[0];
74+
}
75+
}
76+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// build-pass
2+
3+
use core::ops::{Add, AddAssign};
4+
use spirv_std::arch::workgroup_memory_barrier_with_group_sync;
5+
use spirv_std::glam::*;
6+
use spirv_std::spirv;
7+
8+
#[derive(Copy, Clone, Debug)]
9+
pub struct Nested(pub i32);
10+
11+
impl Add for Nested {
12+
type Output = Self;
13+
14+
fn add(self, rhs: Self) -> Self::Output {
15+
Self(self.0 + rhs.0)
16+
}
17+
}
18+
19+
impl AddAssign for Nested {
20+
fn add_assign(&mut self, rhs: Self) {
21+
*self = *self + rhs;
22+
}
23+
}
24+
25+
#[derive(Copy, Clone, Debug)]
26+
pub struct Value {
27+
pub a: f32,
28+
pub b: UVec4,
29+
pub c: Nested,
30+
pub d: Mat4,
31+
}
32+
33+
impl Add for Value {
34+
type Output = Self;
35+
36+
fn add(self, rhs: Self) -> Self::Output {
37+
Self {
38+
a: self.a + rhs.a,
39+
b: self.b + rhs.b,
40+
c: self.c + rhs.c,
41+
d: self.d + rhs.d,
42+
}
43+
}
44+
}
45+
46+
impl AddAssign for Value {
47+
fn add_assign(&mut self, rhs: Self) {
48+
*self = *self + rhs;
49+
}
50+
}
51+
52+
pub const WG_SIZE_SHIFT: usize = 5;
53+
pub const WG_SIZE: usize = 1 << WG_SIZE_SHIFT;
54+
55+
// threads must be a literal, constants don't work
56+
const _: () = {
57+
assert!(WG_SIZE == 32);
58+
};
59+
#[spirv(compute(threads(32)))]
60+
pub fn main(
61+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
62+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
63+
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
64+
#[spirv(local_invocation_index)] inv_id: UVec3,
65+
) {
66+
unsafe {
67+
let inv_id = inv_id.x as usize;
68+
shared[inv_id] = input[inv_id];
69+
workgroup_memory_barrier_with_group_sync();
70+
71+
let mut mask = WG_SIZE << 1;
72+
while mask != 0 {
73+
if inv_id < mask {
74+
shared[inv_id] += shared[inv_id + mask];
75+
}
76+
workgroup_memory_barrier_with_group_sync();
77+
mask <<= 1;
78+
}
79+
80+
if inv_id == 0 {
81+
*output = shared[0];
82+
}
83+
}
84+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// build-pass
2+
3+
use spirv_std::arch::workgroup_memory_barrier_with_group_sync;
4+
use spirv_std::glam::*;
5+
use spirv_std::spirv;
6+
7+
pub type Value = i32;
8+
9+
pub const WG_SIZE_SHIFT: usize = 5;
10+
pub const WG_SIZE: usize = 1 << WG_SIZE_SHIFT;
11+
12+
// threads must be a literal, constants don't work
13+
const _: () = {
14+
assert!(WG_SIZE == 32);
15+
};
16+
#[spirv(compute(threads(32)))]
17+
pub fn main(
18+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
19+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
20+
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
21+
#[spirv(local_invocation_index)] inv_id: UVec3,
22+
) {
23+
unsafe {
24+
let inv_id = inv_id.x as usize;
25+
shared[inv_id] = input[inv_id];
26+
workgroup_memory_barrier_with_group_sync();
27+
28+
let mut mask = WG_SIZE << 1;
29+
while mask != 0 {
30+
if inv_id < mask {
31+
shared[inv_id] += shared[inv_id + mask];
32+
}
33+
workgroup_memory_barrier_with_group_sync();
34+
mask <<= 1;
35+
}
36+
37+
if inv_id == 0 {
38+
*output = shared[0];
39+
}
40+
}
41+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// build-pass
2+
3+
use spirv_std::arch::workgroup_memory_barrier_with_group_sync;
4+
use spirv_std::glam::*;
5+
use spirv_std::spirv;
6+
7+
pub type Value = UVec4;
8+
9+
pub const WG_SIZE_SHIFT: usize = 5;
10+
pub const WG_SIZE: usize = 1 << WG_SIZE_SHIFT;
11+
12+
// threads must be a literal, constants don't work
13+
const _: () = {
14+
assert!(WG_SIZE == 32);
15+
};
16+
#[spirv(compute(threads(32)))]
17+
pub fn main(
18+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
19+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
20+
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
21+
#[spirv(local_invocation_index)] inv_id: UVec3,
22+
) {
23+
unsafe {
24+
let inv_id = inv_id.x as usize;
25+
shared[inv_id] = input[inv_id];
26+
workgroup_memory_barrier_with_group_sync();
27+
28+
let mut mask = WG_SIZE << 1;
29+
while mask != 0 {
30+
if inv_id < mask {
31+
shared[inv_id] += shared[inv_id + mask];
32+
}
33+
workgroup_memory_barrier_with_group_sync();
34+
mask <<= 1;
35+
}
36+
37+
if inv_id == 0 {
38+
*output = shared[0];
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)