11/*
2- * Copyright 2016-2019 Aaron Barany
2+ * Copyright 2016-2022 Aaron Barany
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -31,11 +31,11 @@ namespace msl
3131namespace
3232{
3333
34- static const unsigned int minVersion = 0x00010000 ;
35- static const unsigned int firstInstruction = 5 ;
36- static const unsigned int unknownLength = (unsigned int )-1 ;
34+ const unsigned int minVersion = 0x00010000 ;
35+ const unsigned int firstInstruction = 5 ;
36+ const unsigned int unknownLength = (unsigned int )-1 ;
3737
38- static const char * stageNames[] =
38+ const char * stageNames[] =
3939{
4040 " vertex" ,
4141 " tessellation_control" ,
@@ -47,7 +47,7 @@ static const char* stageNames[] =
4747static_assert (sizeof (stageNames)/sizeof (*stageNames) == stageCount,
4848 " stage name array is out of sync with enum" );
4949
50- static std::uint32_t typeSizes[] =
50+ std::uint32_t typeSizes[] =
5151{
5252 // Scalars and vectors
5353 static_cast <std::uint32_t >(sizeof (float )), // Float
@@ -1266,9 +1266,10 @@ bool addInputsOutputs(Output& output, std::vector<SpirVProcessor::InputOutput>&
12661266
12671267 inputOutput.type = getType (arrayElements, inputOutput.structIndex , processor, data, typeId);
12681268 inputOutput.arrayElements = makeArrayLengths (arrayElements);
1269+ inputOutput.block = data.blocks .find (underlyingTypeId) != data.blocks .end ();
12691270 inputOutput.patch = data.patchVars .find (inputOutputIndices.first ) != data.patchVars .end ();
12701271 inputOutput.autoAssigned = true ;
1271- if (inputOutput.type == Type::Struct )
1272+ if (inputOutput.block )
12721273 {
12731274 const Struct& structType = processor.structs [inputOutput.structIndex ];
12741275
@@ -1282,19 +1283,6 @@ bool addInputsOutputs(Output& output, std::vector<SpirVProcessor::InputOutput>&
12821283 }
12831284 data.inputOutputStructs .insert (underlyingTypeId);
12841285
1285- // Make sure there's no recursive structs.
1286- for (const StructMember& member : structType.members )
1287- {
1288- if (member.type == Type::Struct)
1289- {
1290- output.addMessage (Output::Level::Error, processor.fileName , processor.line ,
1291- processor.column , false ,
1292- " linker error: " + ioName + " member " + structType.name + " ." +
1293- member.name + " is a struct" );
1294- return false ;
1295- }
1296- }
1297-
12981286 // Don't allow arbitrary arrays of input/output blocks.
12991287 bool shouldBeArray = !inputOutput.patch && (&inputOutputs == &processor.inputs ?
13001288 inputIsArray (processor.stage ) : outputIsArray (processor.stage ));
@@ -1423,8 +1411,9 @@ bool addComponents(std::vector<std::uint8_t>& locations, std::size_t curLocation
14231411}
14241412
14251413bool fillLocation (std::vector<std::uint8_t >& locations, std::size_t & curLocation,
1426- std::uint32_t component, Type type, const std::vector<std::uint32_t >& arrayElements,
1427- bool removeFirstArray)
1414+ std::uint32_t component, Type type, std::uint32_t structIndex,
1415+ const std::vector<std::uint32_t >& arrayElements, bool removeFirstArray,
1416+ const std::vector<Struct>& structs)
14281417{
14291418 assert (component < 4 );
14301419
@@ -1552,8 +1541,10 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
15521541 for (std::uint32_t i = 0 ; i < elementCount; ++i, ++curLocation)
15531542 {
15541543 if (!addComponents (locations, curLocation,
1555- (1 << component) | (2 << component) | (4 << component)))
1544+ (1 << component) | (2 << component) | (4 << component)))
1545+ {
15561546 return false ;
1547+ }
15571548 }
15581549 break ;
15591550
@@ -1614,6 +1605,25 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
16141605 }
16151606 break ;
16161607
1608+ case Type::Struct:
1609+ {
1610+ if (component != 0 )
1611+ return false ;
1612+ for (std::uint32_t i = 0 ; i < elementCount; ++i)
1613+ {
1614+ for (const StructMember& member : structs[structIndex].members )
1615+ {
1616+ if (!fillLocation (locations, curLocation, component, member.type ,
1617+ member.structIndex , makeArrayLengths (member.arrayElements ), false ,
1618+ structs))
1619+ {
1620+ return false ;
1621+ }
1622+ }
1623+ }
1624+ break ;
1625+ }
1626+
16171627 default :
16181628 assert (false );
16191629 return false ;
@@ -1623,7 +1633,8 @@ bool fillLocation(std::vector<std::uint8_t>& locations, std::size_t& curLocation
16231633}
16241634
16251635bool assignInputsOutputs (Output& output, const SpirVProcessor& processor,
1626- std::vector<SpirVProcessor::InputOutput>& inputsOutputs, bool removeFirstArray)
1636+ std::vector<SpirVProcessor::InputOutput>& inputsOutputs, bool removeFirstArray,
1637+ const std::vector<Struct>& structs)
16271638{
16281639 std::string ioName = &inputsOutputs == &processor.inputs ? " input" : " output" ;
16291640 std::size_t curLocation = 0 ;
@@ -1633,7 +1644,7 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16331644
16341645 for (SpirVProcessor::InputOutput& io : inputsOutputs)
16351646 {
1636- if (io.type == Type::Struct )
1647+ if (io.block )
16371648 {
16381649 const Struct& ioStruct = processor.structs [io.structIndex ];
16391650 if (io.memberLocations .empty () || io.memberLocations [0 ].first == unknown)
@@ -1654,6 +1665,7 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16541665
16551666 for (std::size_t i = 0 ; i < ioStruct.members .size (); ++i)
16561667 {
1668+ const StructMember& member = ioStruct.members [i];
16571669 std::uint32_t component = 0 ;
16581670 if (io.memberLocations [i].first == unknown)
16591671 {
@@ -1666,8 +1678,9 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16661678 component = io.memberLocations [i].second ;
16671679 }
16681680
1669- if (!fillLocation (locations, curLocation, component, ioStruct.members [i].type ,
1670- makeArrayLengths (ioStruct.members [i].arrayElements ), false ))
1681+ if (!fillLocation (locations, curLocation, component, member.type ,
1682+ member.structIndex , makeArrayLengths (member.arrayElements ), false ,
1683+ structs))
16711684 {
16721685 output.addMessage (Output::Level::Error, processor.fileName , processor.line ,
16731686 processor.column , false ,
@@ -1693,8 +1706,8 @@ bool assignInputsOutputs(Output& output, const SpirVProcessor& processor,
16931706 hasExplicitLocations = true ;
16941707 }
16951708
1696- if (!fillLocation (locations, curLocation, component, io.type , io.arrayElements ,
1697- removeFirstArray))
1709+ if (!fillLocation (locations, curLocation, component, io.type , io.structIndex ,
1710+ io. arrayElements , removeFirstArray, structs ))
16981711 {
16991712 output.addMessage (Output::Level::Error, processor.fileName , processor.line ,
17001713 processor.column , false ,
@@ -1724,7 +1737,7 @@ bool findLinkedMember(Output& output, std::uint32_t& outputIndex, std::uint32_t&
17241737
17251738 for (std::uint32_t i = 0 ; i < processor.outputs .size (); ++i)
17261739 {
1727- if (processor.outputs [i] . type != Type::Struct )
1740+ if (! processor.outputs [i]. block )
17281741 continue ;
17291742
17301743 const Struct& outputStruct = processor.structs [processor.outputs [i].structIndex ];
@@ -1781,6 +1794,51 @@ bool inputOutputArraysEqual(const std::vector<std::uint32_t>& outputArray, bool
17811794 return true ;
17821795}
17831796
1797+ bool inputOutputArraysEqual (const std::vector<ArrayInfo>& outputArray,
1798+ const std::vector<ArrayInfo>& inputArray)
1799+ {
1800+ if (outputArray.size () != inputArray.size ())
1801+ return false ;
1802+
1803+ for (std::size_t i = 0 ; i < outputArray.size (); ++i)
1804+ {
1805+ if (outputArray[i].length != inputArray[i].length ||
1806+ outputArray[i].stride != inputArray[i].stride )
1807+ {
1808+ return false ;
1809+ }
1810+ }
1811+
1812+ return true ;
1813+ }
1814+
1815+ bool inputOutputStructsEqual (const Struct& outputStruct,
1816+ const std::vector<Struct>& outputStructTypes, const Struct& inputStruct,
1817+ const std::vector<Struct>& inputStructTypes)
1818+ {
1819+ if (outputStruct.members .size () != inputStruct.members .size ())
1820+ return false ;
1821+
1822+ for (std::size_t i = 0 ; i < outputStruct.members .size (); ++i)
1823+ {
1824+ const StructMember& outputMember = outputStruct.members [i];
1825+ const StructMember& inputMember = inputStruct.members [i];
1826+ if (outputMember.offset != inputMember.offset ||
1827+ outputMember.size != inputMember.size ||
1828+ outputMember.type != inputMember.type ||
1829+ !inputOutputArraysEqual (outputMember.arrayElements , inputMember.arrayElements ) ||
1830+ outputMember.rowMajor != inputMember.rowMajor ||
1831+ (outputMember.type == Type::Struct && !inputOutputStructsEqual (
1832+ outputStructTypes[outputMember.structIndex ], outputStructTypes,
1833+ inputStructTypes[inputMember.structIndex ], inputStructTypes)))
1834+ {
1835+ return false ;
1836+ }
1837+ }
1838+
1839+ return true ;
1840+ }
1841+
17841842void addDummyDescriptorSet (std::vector<std::uint32_t >& spirv, std::uint32_t id)
17851843{
17861844 spirv.push_back ((4 << spv::WordCountShift) | spv::OpDecorate);
@@ -2351,7 +2409,7 @@ bool SpirVProcessor::extract(Output& output, const std::string& fileName, std::s
23512409 encounteredNames.clear ();
23522410 for (const InputOutput& stageInput : inputs)
23532411 {
2354- if (stageInput.type == Type::Struct )
2412+ if (stageInput.block )
23552413 continue ;
23562414
23572415 if (!encounteredNames.insert (stageInput.name ).second )
@@ -2366,7 +2424,7 @@ bool SpirVProcessor::extract(Output& output, const std::string& fileName, std::s
23662424 encounteredNames.clear ();
23672425 for (const InputOutput& stageOutput : outputs)
23682426 {
2369- if (stageOutput.type == Type::Struct )
2427+ if (stageOutput.block )
23702428 continue ;
23712429
23722430 if (!encounteredNames.insert (stageOutput.name ).second )
@@ -2463,12 +2521,12 @@ bool SpirVProcessor::uniformsCompatible(Output& output, const SpirVProcessor& ot
24632521
24642522bool SpirVProcessor::assignInputs (Output& output)
24652523{
2466- return assignInputsOutputs (output, *this , inputs, inputIsArray (stage));
2524+ return assignInputsOutputs (output, *this , inputs, inputIsArray (stage), structs );
24672525}
24682526
24692527bool SpirVProcessor::assignOutputs (Output& output)
24702528{
2471- return assignInputsOutputs (output, *this , outputs, outputIsArray (stage));
2529+ return assignInputsOutputs (output, *this , outputs, outputIsArray (stage), structs );
24722530}
24732531
24742532bool SpirVProcessor::linkInputs (Output& output, const SpirVProcessor& prevStage)
@@ -2478,7 +2536,7 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
24782536 bool outputArrays = outputIsArray (prevStage.stage );
24792537 for (InputOutput& input : inputs)
24802538 {
2481- if (input.type == Type::Struct )
2539+ if (input.block )
24822540 {
24832541 Struct& inputStruct = structs[input.structIndex ];
24842542 assert (inputStruct.members .size () == input.memberLocations .size ());
@@ -2489,7 +2547,7 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
24892547
24902548 std::uint32_t otherOutIndex, otherMemberIndex;
24912549 if (!findLinkedMember (output, otherOutIndex, otherMemberIndex, prevStage,
2492- inputStruct.members [i].name ))
2550+ inputStruct.members [i].name ))
24932551 {
24942552 success = false ;
24952553 continue ;
@@ -2500,8 +2558,8 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
25002558 if (inputStruct.members [i].type != outputStruct.members [otherMemberIndex].type ||
25012559 input.patch != prevStage.outputs [otherOutIndex].patch ||
25022560 !inputOutputArraysEqual (
2503- makeArrayLengths ( outputStruct.members [otherMemberIndex].arrayElements ) ,
2504- false , makeArrayLengths ( inputStruct.members [i].arrayElements ), false ))
2561+ outputStruct.members [otherMemberIndex].arrayElements ,
2562+ inputStruct.members [i].arrayElements ))
25052563 {
25062564 output.addMessage (Output::Level::Error, fileName, line, column, false ,
25072565 " linker error: type mismatch when linking input member " +
@@ -2529,7 +2587,10 @@ bool SpirVProcessor::linkInputs(Output& output, const SpirVProcessor& prevStage)
25292587 found = true ;
25302588 if (input.type != out.type || input.patch != out.patch ||
25312589 !inputOutputArraysEqual (out.arrayElements , outputArrays && !out.patch ,
2532- input.arrayElements , inputArrays && !input.patch ))
2590+ input.arrayElements , inputArrays && !input.patch ) ||
2591+ (input.type == Type::Struct && !inputOutputStructsEqual (
2592+ prevStage.structs [out.structIndex ], prevStage.structs ,
2593+ structs[input.structIndex ], structs)))
25332594 {
25342595 output.addMessage (Output::Level::Error, fileName, line, column, false ,
25352596 " linker error: type mismatch when linking input " + input.name +
@@ -2674,7 +2735,7 @@ std::vector<std::uint32_t> SpirVProcessor::process(Strip strip, bool dummyBindin
26742735 if (!inputs[j].autoAssigned )
26752736 continue ;
26762737
2677- if (inputs[j].type == Type::Struct )
2738+ if (inputs[j].block )
26782739 {
26792740 std::uint32_t typeId = structIds[inputs[j].structIndex ];
26802741 auto foundMember = memberLocations.find (typeId);
@@ -2707,7 +2768,7 @@ std::vector<std::uint32_t> SpirVProcessor::process(Strip strip, bool dummyBindin
27072768 if (!outputs[j].autoAssigned )
27082769 continue ;
27092770
2710- if (outputs[j].type == Type::Struct )
2771+ if (outputs[j].block )
27112772 {
27122773 std::uint32_t typeId = structIds[outputs[j].structIndex ];
27132774 auto foundMember = memberLocations.find (typeId);
0 commit comments