Skip to content

Commit 9509d90

Browse files
committed
[naga hlsl-out] Handle additional cases of Cx2 matrices
Fixes #4423
1 parent 3d0fe3a commit 9509d90

13 files changed

+1489
-235
lines changed

naga/src/back/hlsl/mod.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@ type should be stored in `uniform` and `storage` buffers. The HLSL we
1313
generate must access values in that form, even when it is not what
1414
HLSL would use normally.
1515
16-
The rules described here only apply to WGSL `uniform` variables. WGSL
17-
`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
18-
which we generate `Load` and `Store` method calls with explicit byte
19-
offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
20-
be matrices, which is where the interesting problems arise.
16+
Matching the WGSL memory layout is a concern only for `uniform`
17+
variables. WGSL `storage` buffers are translated as HLSL
18+
`ByteAddressBuffers`, for which we generate `Load` and `Store` method
19+
calls with explicit byte offsets. WGSL pipeline inputs must be scalars
20+
or vectors; they cannot be matrices, which is where the interesting
21+
problems arise. However, when an affected type appears in a struct
22+
definition, the transformations described here are applied without
23+
consideration of where the struct is used.
24+
25+
Access to storage buffers is implemented in `storage.rs`. Access to
26+
uniform buffers is implemented where applicable in `writer.rs`.
2127
2228
## Row- and column-major ordering for matrices
2329
@@ -57,10 +63,9 @@ that the columns of a `matKx2<f32>` need only be [aligned as required
5763
for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
5864
5965
To compensate for this, any time a `matKx2<f32>` appears in a WGSL
60-
`uniform` variable, whether directly as the variable's type or as part
61-
of a struct/array, we actually emit `K` separate `float2` members, and
62-
assemble/disassemble the matrix from its columns (in WGSL; rows in
63-
HLSL) upon load and store.
66+
`uniform` value or as part of a struct/array, we actually emit `K`
67+
separate `float2` members, and assemble/disassemble the matrix from its
68+
columns (in WGSL; rows in HLSL) upon load and store.
6469
6570
For example, the following WGSL struct type:
6671

naga/src/back/hlsl/storage.rs

Lines changed: 113 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ pub(super) enum StoreValue {
108108
base: Handle<crate::Type>,
109109
member_index: u32,
110110
},
111+
// Access to a single column of a Cx2 matrix within a struct
112+
TempColumnAccess {
113+
depth: usize,
114+
base: Handle<crate::Type>,
115+
member_index: u32,
116+
column: u32,
117+
},
111118
}
112119

113120
impl<W: fmt::Write> super::Writer<'_, W> {
@@ -290,6 +297,15 @@ impl<W: fmt::Write> super::Writer<'_, W> {
290297
let name = &self.names[&NameKey::StructMember(base, member_index)];
291298
write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
292299
}
300+
StoreValue::TempColumnAccess {
301+
depth,
302+
base,
303+
member_index,
304+
column,
305+
} => {
306+
let name = &self.names[&NameKey::StructMember(base, member_index)];
307+
write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")?
308+
}
293309
}
294310
Ok(())
295311
}
@@ -302,6 +318,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
302318
value: StoreValue,
303319
func_ctx: &FunctionCtx,
304320
level: crate::back::Level,
321+
within_struct: Option<Handle<crate::Type>>,
305322
) -> BackendResult {
306323
let temp_resolution;
307324
let ty_resolution = match value {
@@ -325,6 +342,9 @@ impl<W: fmt::Write> super::Writer<'_, W> {
325342
temp_resolution = TypeResolution::Handle(ty_handle);
326343
&temp_resolution
327344
}
345+
StoreValue::TempColumnAccess { .. } => {
346+
unreachable!("attempting write_storage_store for TempColumnAccess");
347+
}
328348
};
329349
match *ty_resolution.inner_with(&module.types) {
330350
crate::TypeInner::Scalar(scalar) => {
@@ -372,37 +392,92 @@ impl<W: fmt::Write> super::Writer<'_, W> {
372392
rows,
373393
scalar,
374394
} => {
375-
// first, assign the value to a temporary
376-
writeln!(self.out, "{level}{{")?;
377-
let depth = level.0 + 1;
378-
write!(
379-
self.out,
380-
"{}{}{}x{} {}{} = ",
381-
level.next(),
382-
scalar.to_hlsl_str()?,
383-
columns as u8,
384-
rows as u8,
385-
STORE_TEMP_NAME,
386-
depth,
387-
)?;
388-
self.write_store_value(module, &value, func_ctx)?;
389-
writeln!(self.out, ";")?;
390-
391395
// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
392396
let row_stride = Alignment::from(rows) * scalar.width as u32;
393397

394-
// then iterate the stores
395-
for i in 0..columns as u32 {
396-
self.temp_access_chain
397-
.push(SubAccess::Offset(i * row_stride));
398-
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
399-
let sv = StoreValue::TempIndex {
400-
depth,
401-
index: i,
402-
ty: TypeResolution::Value(ty_inner),
403-
};
404-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
405-
self.temp_access_chain.pop();
398+
writeln!(self.out, "{level}{{")?;
399+
400+
match within_struct {
401+
Some(containing_struct) if rows == crate::VectorSize::Bi => {
402+
// If we are within a struct, then the struct was already assigned to
403+
// a temporary, we don't need to make another.
404+
let mut chain = mem::take(&mut self.temp_access_chain);
405+
for i in 0..columns as u32 {
406+
chain.push(SubAccess::Offset(i * row_stride));
407+
// working around the borrow checker in `self.write_expr`
408+
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
409+
let StoreValue::TempAccess { member_index, .. } = value else {
410+
unreachable!(
411+
"write_storage_store within_struct but not TempAccess"
412+
);
413+
};
414+
let column_value = StoreValue::TempColumnAccess {
415+
depth: level.0, // note not incrementing, b/c no temp
416+
base: containing_struct,
417+
member_index,
418+
column: i,
419+
};
420+
// See note about DXC and Load/Store in the module's documentation.
421+
if scalar.width == 4 {
422+
write!(
423+
self.out,
424+
"{}{}.Store{}(",
425+
level.next(),
426+
var_name,
427+
rows as u8
428+
)?;
429+
self.write_storage_address(module, &chain, func_ctx)?;
430+
write!(self.out, ", asuint(")?;
431+
self.write_store_value(module, &column_value, func_ctx)?;
432+
writeln!(self.out, "));")?;
433+
} else {
434+
write!(self.out, "{}{var_name}.Store(", level.next())?;
435+
self.write_storage_address(module, &chain, func_ctx)?;
436+
write!(self.out, ", ")?;
437+
self.write_store_value(module, &column_value, func_ctx)?;
438+
writeln!(self.out, ");")?;
439+
}
440+
chain.pop();
441+
}
442+
self.temp_access_chain = chain;
443+
}
444+
_ => {
445+
// first, assign the value to a temporary
446+
let depth = level.0 + 1;
447+
write!(
448+
self.out,
449+
"{}{}{}x{} {}{} = ",
450+
level.next(),
451+
scalar.to_hlsl_str()?,
452+
columns as u8,
453+
rows as u8,
454+
STORE_TEMP_NAME,
455+
depth,
456+
)?;
457+
self.write_store_value(module, &value, func_ctx)?;
458+
writeln!(self.out, ";")?;
459+
460+
// then iterate the stores
461+
for i in 0..columns as u32 {
462+
self.temp_access_chain
463+
.push(SubAccess::Offset(i * row_stride));
464+
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
465+
let sv = StoreValue::TempIndex {
466+
depth,
467+
index: i,
468+
ty: TypeResolution::Value(ty_inner),
469+
};
470+
self.write_storage_store(
471+
module,
472+
var_handle,
473+
sv,
474+
func_ctx,
475+
level.next(),
476+
None,
477+
)?;
478+
self.temp_access_chain.pop();
479+
}
480+
}
406481
}
407482
// done
408483
writeln!(self.out, "{level}}}")?;
@@ -415,7 +490,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
415490
// first, assign the value to a temporary
416491
writeln!(self.out, "{level}{{")?;
417492
write!(self.out, "{}", level.next())?;
418-
self.write_value_type(module, &module.types[base].inner)?;
493+
self.write_type(module, base)?;
419494
let depth = level.next().0;
420495
write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
421496
self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
@@ -430,7 +505,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
430505
index: i,
431506
ty: TypeResolution::Handle(base),
432507
};
433-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
508+
self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?;
434509
self.temp_access_chain.pop();
435510
}
436511
// done
@@ -461,7 +536,14 @@ impl<W: fmt::Write> super::Writer<'_, W> {
461536
base: struct_ty,
462537
member_index: i as u32,
463538
};
464-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
539+
self.write_storage_store(
540+
module,
541+
var_handle,
542+
sv,
543+
func_ctx,
544+
level.next(),
545+
Some(struct_ty),
546+
)?;
465547
self.temp_access_chain.pop();
466548
}
467549
// done

naga/src/back/hlsl/writer.rs

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,6 +1945,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
19451945
StoreValue::Expression(value),
19461946
func_ctx,
19471947
level,
1948+
None,
19481949
)?;
19491950
} else {
19501951
// We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
@@ -2963,6 +2964,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
29632964
rows: crate::VectorSize::Bi,
29642965
width: 4,
29652966
}) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2967+
.or_else(|| get_global_uniform_matrix(module, base, func_ctx))
29662968
{
29672969
write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
29682970
self.write_expr(module, base, func_ctx)?;
@@ -3075,13 +3077,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
30753077
{
30763078
// do nothing, the chain is written on `Load`/`Store`
30773079
} else {
3078-
// We write the matrix column access in a special way since
3079-
// the type of `base` is our special __matCx2 struct.
3080+
// See if we need to write the matrix column access in a
3081+
// special way since the type of `base` is our special
3082+
// __matCx2 struct.
30803083
if let Some(MatrixType {
30813084
rows: crate::VectorSize::Bi,
30823085
width: 4,
30833086
..
30843087
}) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3088+
.or_else(|| get_global_uniform_matrix(module, base, func_ctx))
30853089
{
30863090
self.write_expr(module, base, func_ctx)?;
30873091
write!(self.out, "._{index}")?;
@@ -3381,8 +3385,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
33813385
.or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
33823386
{
33833387
let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3384-
if let TypeInner::Pointer { base, .. } = *resolved {
3385-
resolved = &module.types[base].inner;
3388+
let ptr_tr = resolved.pointer_base_type();
3389+
if let Some(ptr_ty) =
3390+
ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3391+
{
3392+
resolved = ptr_ty;
33863393
}
33873394

33883395
write!(self.out, "((")?;
@@ -4416,6 +4423,32 @@ pub(super) fn get_inner_matrix_data(
44164423
}
44174424
}
44184425

4426+
fn find_matrix_in_access_chain(
4427+
module: &Module,
4428+
base: Handle<crate::Expression>,
4429+
func_ctx: &back::FunctionCtx<'_>,
4430+
) -> Option<Handle<crate::Expression>> {
4431+
let mut current_base = base;
4432+
loop {
4433+
let resolved_tr = func_ctx
4434+
.resolve_type(current_base, &module.types)
4435+
.pointer_base_type();
4436+
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4437+
4438+
match *resolved {
4439+
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4440+
TypeInner::Matrix { .. } => return Some(current_base),
4441+
_ => return None,
4442+
}
4443+
4444+
current_base = match func_ctx.expressions[current_base] {
4445+
crate::Expression::Access { base, .. } => base,
4446+
crate::Expression::AccessIndex { base, .. } => base,
4447+
_ => return None,
4448+
}
4449+
}
4450+
}
4451+
44194452
/// Returns the matrix data if the access chain starting at `base`:
44204453
/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
44214454
/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
@@ -4474,6 +4507,36 @@ pub(super) fn get_inner_matrix_of_struct_array_member(
44744507
None
44754508
}
44764509

4510+
/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4511+
/// immediate expression, rather than traversing an access chain.
4512+
fn get_global_uniform_matrix(
4513+
module: &Module,
4514+
base: Handle<crate::Expression>,
4515+
func_ctx: &back::FunctionCtx<'_>,
4516+
) -> Option<MatrixType> {
4517+
let base_tr = func_ctx
4518+
.resolve_type(base, &module.types)
4519+
.pointer_base_type();
4520+
let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4521+
match (&func_ctx.expressions[base], base_ty) {
4522+
(
4523+
&crate::Expression::GlobalVariable(handle),
4524+
Some(&TypeInner::Matrix {
4525+
columns,
4526+
rows,
4527+
scalar,
4528+
}),
4529+
) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4530+
Some(MatrixType {
4531+
columns,
4532+
rows,
4533+
width: scalar.width,
4534+
})
4535+
}
4536+
_ => None,
4537+
}
4538+
}
4539+
44774540
/// Returns the matrix data if the access chain starting at `base`:
44784541
/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
44794542
/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]

naga/tests/in/wgsl/access.wgsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ var<uniform> baz: Baz;
3535
var<storage,read_write> qux: vec2<i32>;
3636

3737
fn test_matrix_within_struct_accesses() {
38+
// Test HLSL accesses to Cx2 matrices. There are additional tests
39+
// in `hlsl_mat_cx2.wgsl`.
40+
3841
var idx = 1;
3942

4043
idx--;

naga/tests/in/wgsl/hlsl_mat_cx2.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
targets = "HLSL"

0 commit comments

Comments
 (0)