Skip to content

Commit cd33c46

Browse files
committed
Set location values for input/output struct values.
Use member locations only for input/output blocks Standard input/output variables that use structs now set variable locations rather than member locations. This relaxes the struct requirements for inputs/outputs used for MSL. This also avoids outputting invalid SPIR-V, which can fail the validator as well as work incorrectly on some hardware. Incremented version to 1.7.0.
1 parent f1c9b57 commit cd33c46

File tree

12 files changed

+625
-369
lines changed

12 files changed

+625
-369
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ if (MSL_OUTPUT_DIR)
8080
endif()
8181

8282
set(MSL_MAJOR_VERSION 1)
83-
set(MSL_MINOR_VERSION 6)
84-
set(MSL_PATCH_VERSION 1)
83+
set(MSL_MINOR_VERSION 7)
84+
set(MSL_PATCH_VERSION 0)
8585
set(MSL_VERSION ${MSL_MAJOR_VERSION}.${MSL_MINOR_VERSION}.${MSL_PATCH_VERSION})
8686

8787
set(MSL_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})

Compile/include/MSL/Compile/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,9 @@ struct ArrayInfo
840840

841841
/**
842842
* @brief The stride of the array.
843+
*
844+
* This will be unknown if not explicitly set by SPIR-V decorations, in which case the standard
845+
* data type will determine the stride.
843846
*/
844847
std::uint32_t stride;
845848
};

Compile/src/SpirVProcessor.cpp

Lines changed: 102 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2019 Aaron Barany
2+
* Copyright 2016-2022 Aaron Barany
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,11 +31,11 @@ namespace msl
3131
namespace
3232
{
3333

34-
static const unsigned int minVersion = 0x00010000;
35-
static const unsigned int firstInstruction = 5;
36-
static const unsigned int unknownLength = (unsigned int)-1;
34+
const unsigned int minVersion = 0x00010000;
35+
const unsigned int firstInstruction = 5;
36+
const unsigned int unknownLength = (unsigned int)-1;
3737

38-
static const char* stageNames[] =
38+
const char* stageNames[] =
3939
{
4040
"vertex",
4141
"tessellation_control",
@@ -47,7 +47,7 @@ static const char* stageNames[] =
4747
static_assert(sizeof(stageNames)/sizeof(*stageNames) == stageCount,
4848
"stage name array is out of sync with enum");
4949

50-
static std::uint32_t typeSizes[] =
50+
std::uint32_t typeSizes[] =
5151
{
5252
// Scalars and vectors
5353
static_cast<std::uint32_t>(sizeof(float)), // Float
@@ -1266,9 +1266,10 @@ bool addInputsOutputs(Output& output, std::vector<SpirVProcessor::InputOutput>&
12661266

12671267
inputOutput.type = getType(arrayElements, inputOutput.structIndex, processor, data, typeId);
12681268
inputOutput.arrayElements = makeArrayLengths(arrayElements);
1269+
inputOutput.block = data.blocks.find(underlyingTypeId) != data.blocks.end();
12691270
inputOutput.patch = data.patchVars.find(inputOutputIndices.first) != data.patchVars.end();
12701271
inputOutput.autoAssigned = true;
1271-
if (inputOutput.type == Type::Struct)
1272+
if (inputOutput.block)
12721273
{
12731274
const Struct& structType = processor.structs[inputOutput.structIndex];
12741275

@@ -1282,19 +1283,6 @@ bool addInputsOutputs(Output& output, std::vector<SpirVProcessor::InputOutput>&
12821283
}
12831284
data.inputOutputStructs.insert(underlyingTypeId);
12841285

1285-
// Make sure there's no recursive structs.
1286-
for (const StructMember& member : structType.members)
1287-
{
1288-
if (member.type == Type::Struct)
1289-
{
1290-
output.addMessage(Output::Level::Error, processor.fileName, processor.line,
1291-
processor.column, false,
1292-
"linker error: " + ioName + " member " + structType.name + "." +
1293-
member.name + " is a struct");
1294-
return false;
1295-
}
1296-
}
1297-
12981286
// Don't allow arbitrary arrays of input/output blocks.
12991287
bool shouldBeArray = !inputOutput.patch && (&inputOutputs == &processor.inputs ?
13001288
inputIsArray(processor.stage) : outputIsArray(processor.stage));
@@ -1423,8 +1411,9 @@ bool addComponents(std::vector<std::uint8_t>& locations, std::size_t curLocation
14231411
}
14241412

14251413
bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation,
1426-
std::uint32_t component, Type type, const std::vector<std::uint32_t>& arrayElements,
1427-
bool removeFirstArray)
1414+
std::uint32_t component, Type type, std::uint32_t structIndex,
1415+
const std::vector<std::uint32_t>& arrayElements, bool removeFirstArray,
1416+
const std::vector<Struct>& structs)
14281417
{
14291418
assert(component < 4);
14301419

@@ -1552,8 +1541,10 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
15521541
for (std::uint32_t i = 0; i < elementCount; ++i, ++curLocation)
15531542
{
15541543
if (!addComponents(locations, curLocation,
1555-
(1 << component) | (2 << component) | (4 << component)))
1544+
(1 << component) | (2 << component) | (4 << component)))
1545+
{
15561546
return false;
1547+
}
15571548
}
15581549
break;
15591550

@@ -1614,6 +1605,25 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
16141605
}
16151606
break;
16161607

1608+
case Type::Struct:
1609+
{
1610+
if (component != 0)
1611+
return false;
1612+
for (std::uint32_t i = 0; i < elementCount; ++i)
1613+
{
1614+
for (const StructMember& member : structs[structIndex].members)
1615+
{
1616+
if (!fillLocation(locations, curLocation, component, member.type,
1617+
member.structIndex, makeArrayLengths(member.arrayElements), false,
1618+
structs))
1619+
{
1620+
return false;
1621+
}
1622+
}
1623+
}
1624+
break;
1625+
}
1626+
16171627
default:
16181628
assert(false);
16191629
return false;
@@ -1623,7 +1633,8 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
16231633
}
16241634

16251635
bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
1626-
std::vector<SpirVProcessor::InputOutput>& inputsOutputs, bool removeFirstArray)
1636+
std::vector<SpirVProcessor::InputOutput>& inputsOutputs, bool removeFirstArray,
1637+
const std::vector<Struct>& structs)
16271638
{
16281639
std::string ioName = &inputsOutputs == &processor.inputs ? "input" : "output";
16291640
std::size_t curLocation = 0;
@@ -1633,7 +1644,7 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16331644

16341645
for (SpirVProcessor::InputOutput& io : inputsOutputs)
16351646
{
1636-
if (io.type == Type::Struct)
1647+
if (io.block)
16371648
{
16381649
const Struct& ioStruct = processor.structs[io.structIndex];
16391650
if (io.memberLocations.empty() || io.memberLocations[0].first == unknown)
@@ -1654,6 +1665,7 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16541665

16551666
for (std::size_t i = 0; i < ioStruct.members.size(); ++i)
16561667
{
1668+
const StructMember& member = ioStruct.members[i];
16571669
std::uint32_t component = 0;
16581670
if (io.memberLocations[i].first == unknown)
16591671
{
@@ -1666,8 +1678,9 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16661678
component = io.memberLocations[i].second;
16671679
}
16681680

1669-
if (!fillLocation(locations, curLocation, component, ioStruct.members[i].type,
1670-
makeArrayLengths(ioStruct.members[i].arrayElements), false))
1681+
if (!fillLocation(locations, curLocation, component, member.type,
1682+
member.structIndex, makeArrayLengths(member.arrayElements), false,
1683+
structs))
16711684
{
16721685
output.addMessage(Output::Level::Error, processor.fileName, processor.line,
16731686
processor.column, false,
@@ -1693,8 +1706,8 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16931706
hasExplicitLocations = true;
16941707
}
16951708

1696-
if (!fillLocation(locations, curLocation, component, io.type, io.arrayElements,
1697-
removeFirstArray))
1709+
if (!fillLocation(locations, curLocation, component, io.type, io.structIndex,
1710+
io.arrayElements, removeFirstArray, structs))
16981711
{
16991712
output.addMessage(Output::Level::Error, processor.fileName, processor.line,
17001713
processor.column, false,
@@ -1724,7 +1737,7 @@ bool findLinkedMember(Output& output, std::uint32_t& outputIndex, std::uint32_t&
17241737

17251738
for (std::uint32_t i = 0; i < processor.outputs.size(); ++i)
17261739
{
1727-
if (processor.outputs[i] .type != Type::Struct)
1740+
if (!processor.outputs[i].block)
17281741
continue;
17291742

17301743
const Struct& outputStruct = processor.structs[processor.outputs[i].structIndex];
@@ -1781,6 +1794,51 @@ bool inputOutputArraysEqual(const std::vector<std::uint32_t>& outputArray, bool
17811794
return true;
17821795
}
17831796

1797+
bool inputOutputArraysEqual(const std::vector<ArrayInfo>& outputArray,
1798+
const std::vector<ArrayInfo>& inputArray)
1799+
{
1800+
if (outputArray.size() != inputArray.size())
1801+
return false;
1802+
1803+
for (std::size_t i = 0; i < outputArray.size(); ++i)
1804+
{
1805+
if (outputArray[i].length != inputArray[i].length ||
1806+
outputArray[i].stride != inputArray[i].stride)
1807+
{
1808+
return false;
1809+
}
1810+
}
1811+
1812+
return true;
1813+
}
1814+
1815+
bool inputOutputStructsEqual(const Struct& outputStruct,
1816+
const std::vector<Struct>& outputStructTypes, const Struct& inputStruct,
1817+
const std::vector<Struct>& inputStructTypes)
1818+
{
1819+
if (outputStruct.members.size() != inputStruct.members.size())
1820+
return false;
1821+
1822+
for (std::size_t i = 0; i < outputStruct.members.size(); ++i)
1823+
{
1824+
const StructMember& outputMember = outputStruct.members[i];
1825+
const StructMember& inputMember = inputStruct.members[i];
1826+
if (outputMember.offset != inputMember.offset ||
1827+
outputMember.size != inputMember.size ||
1828+
outputMember.type != inputMember.type ||
1829+
!inputOutputArraysEqual(outputMember.arrayElements, inputMember.arrayElements) ||
1830+
outputMember.rowMajor != inputMember.rowMajor ||
1831+
(outputMember.type == Type::Struct && !inputOutputStructsEqual(
1832+
outputStructTypes[outputMember.structIndex], outputStructTypes,
1833+
inputStructTypes[inputMember.structIndex], inputStructTypes)))
1834+
{
1835+
return false;
1836+
}
1837+
}
1838+
1839+
return true;
1840+
}
1841+
17841842
void addDummyDescriptorSet(std::vector<std::uint32_t>& spirv, std::uint32_t id)
17851843
{
17861844
spirv.push_back((4 << spv::WordCountShift) | spv::OpDecorate);
@@ -2351,7 +2409,7 @@ bool SpirVProcessor::extract(Output& output, const std::string& fileName, std::s
23512409
encounteredNames.clear();
23522410
for (const InputOutput& stageInput : inputs)
23532411
{
2354-
if (stageInput.type == Type::Struct)
2412+
if (stageInput.block)
23552413
continue;
23562414

23572415
if (!encounteredNames.insert(stageInput.name).second)
@@ -2366,7 +2424,7 @@ bool SpirVProcessor::extract(Output& output, const std::string& fileName, std::s
23662424
encounteredNames.clear();
23672425
for (const InputOutput& stageOutput : outputs)
23682426
{
2369-
if (stageOutput.type == Type::Struct)
2427+
if (stageOutput.block)
23702428
continue;
23712429

23722430
if (!encounteredNames.insert(stageOutput.name).second)
@@ -2463,12 +2521,12 @@ bool SpirVProcessor::uniformsCompatible(Output& output, const SpirVProcessor& ot
24632521

24642522
bool SpirVProcessor::assignInputs(Output& output)
24652523
{
2466-
return assignInputsOutputs(output, *this, inputs, inputIsArray(stage));
2524+
return assignInputsOutputs(output, *this, inputs, inputIsArray(stage), structs);
24672525
}
24682526

24692527
bool SpirVProcessor::assignOutputs(Output& output)
24702528
{
2471-
return assignInputsOutputs(output, *this, outputs, outputIsArray(stage));
2529+
return assignInputsOutputs(output, *this, outputs, outputIsArray(stage), structs);
24722530
}
24732531

24742532
bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
@@ -2478,7 +2536,7 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
24782536
bool outputArrays = outputIsArray(prevStage.stage);
24792537
for (InputOutput& input : inputs)
24802538
{
2481-
if (input.type == Type::Struct)
2539+
if (input.block)
24822540
{
24832541
Struct& inputStruct = structs[input.structIndex];
24842542
assert(inputStruct.members.size() == input.memberLocations.size());
@@ -2489,7 +2547,7 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
24892547

24902548
std::uint32_t otherOutIndex, otherMemberIndex;
24912549
if (!findLinkedMember(output, otherOutIndex, otherMemberIndex, prevStage,
2492-
inputStruct.members[i].name))
2550+
inputStruct.members[i].name))
24932551
{
24942552
success = false;
24952553
continue;
@@ -2500,8 +2558,8 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
25002558
if (inputStruct.members[i].type != outputStruct.members[otherMemberIndex].type ||
25012559
input.patch != prevStage.outputs[otherOutIndex].patch ||
25022560
!inputOutputArraysEqual(
2503-
makeArrayLengths(outputStruct.members[otherMemberIndex].arrayElements),
2504-
false, makeArrayLengths(inputStruct.members[i].arrayElements), false))
2561+
outputStruct.members[otherMemberIndex].arrayElements,
2562+
inputStruct.members[i].arrayElements))
25052563
{
25062564
output.addMessage(Output::Level::Error, fileName, line, column, false,
25072565
"linker error: type mismatch when linking input member " +
@@ -2529,7 +2587,10 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
25292587
found = true;
25302588
if (input.type != out.type || input.patch != out.patch ||
25312589
!inputOutputArraysEqual(out.arrayElements, outputArrays && !out.patch,
2532-
input.arrayElements, inputArrays && !input.patch))
2590+
input.arrayElements, inputArrays && !input.patch) ||
2591+
(input.type == Type::Struct && !inputOutputStructsEqual(
2592+
prevStage.structs[out.structIndex], prevStage.structs,
2593+
structs[input.structIndex], structs)))
25332594
{
25342595
output.addMessage(Output::Level::Error, fileName, line, column, false,
25352596
"linker error: type mismatch when linking input " + input.name +
@@ -2674,7 +2735,7 @@ std::vector<std::uint32_t> SpirVProcessor::process(Strip strip, bool dummyBindin
26742735
if (!inputs[j].autoAssigned)
26752736
continue;
26762737

2677-
if (inputs[j].type == Type::Struct)
2738+
if (inputs[j].block)
26782739
{
26792740
std::uint32_t typeId = structIds[inputs[j].structIndex];
26802741
auto foundMember = memberLocations.find(typeId);
@@ -2707,7 +2768,7 @@ std::vector<std::uint32_t> SpirVProcessor::process(Strip strip, bool dummyBindin
27072768
if (!outputs[j].autoAssigned)
27082769
continue;
27092770

2710-
if (outputs[j].type == Type::Struct)
2771+
if (outputs[j].block)
27112772
{
27122773
std::uint32_t typeId = structIds[outputs[j].structIndex];
27132774
auto foundMember = memberLocations.find(typeId);

Compile/src/SpirVProcessor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class MSL_COMPILE_EXPORT SpirVProcessor
4949
std::uint32_t structIndex;
5050
std::vector<std::uint32_t> arrayElements;
5151
std::vector<std::pair<std::uint32_t, std::uint32_t>> memberLocations;
52+
bool block;
5253
bool patch;
5354
bool autoAssigned;
5455
std::uint32_t location;

0 commit comments

Comments
 (0)