@@ -328,6 +328,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328
328
};
329
329
} // namespace
330
330
331
+ static uint32_t GetIntConstAttrArg (ASTContext &astContext, const Expr *expr,
332
+ uint32_t defaultVal = 0 ) {
333
+ if (expr) {
334
+ llvm::APSInt apsInt;
335
+ APValue apValue;
336
+ if (expr->isIntegerConstantExpr (apsInt, astContext))
337
+ return (uint32_t )apsInt.getSExtValue ();
338
+ if (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
339
+ return (uint32_t )apValue.getInt ().getSExtValue ();
340
+ }
341
+ return defaultVal;
342
+ }
343
+
331
344
// ------------------------------------------------------------------------------
332
345
//
333
346
// CGMSHLSLRuntime methods.
@@ -1422,6 +1435,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1422
1435
}
1423
1436
1424
1437
DiagnosticsEngine &Diags = CGM.getDiags ();
1438
+ ASTContext &astContext = CGM.getTypes ().getContext ();
1425
1439
1426
1440
std::unique_ptr<DxilFunctionProps> funcProps =
1427
1441
llvm::make_unique<DxilFunctionProps>();
@@ -1632,10 +1646,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1632
1646
1633
1647
// Populate numThreads
1634
1648
if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1635
-
1636
- funcProps->numThreads [0 ] = Attr->getX ();
1637
- funcProps->numThreads [1 ] = Attr->getY ();
1638
- funcProps->numThreads [2 ] = Attr->getZ ();
1649
+ funcProps->numThreads [0 ] = GetIntConstAttrArg (astContext, Attr->getX (), 1 );
1650
+ funcProps->numThreads [1 ] = GetIntConstAttrArg (astContext, Attr->getY (), 1 );
1651
+ funcProps->numThreads [2 ] = GetIntConstAttrArg (astContext, Attr->getZ (), 1 );
1639
1652
1640
1653
if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
1641
1654
unsigned DiagID = Diags.getCustomDiagID (
@@ -1808,7 +1821,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1808
1821
1809
1822
if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
1810
1823
funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1811
- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1824
+ funcProps->NodeShaderID .Index =
1825
+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
1812
1826
} else {
1813
1827
funcProps->NodeShaderID .Name = FD->getName ().str ();
1814
1828
funcProps->NodeShaderID .Index = 0 ;
@@ -1819,20 +1833,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1819
1833
}
1820
1834
if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
1821
1835
funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1822
- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1836
+ funcProps->NodeShaderSharedInput .Index =
1837
+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
1823
1838
}
1824
1839
if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1825
- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1826
- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1827
- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1840
+ funcProps->Node .DispatchGrid [0 ] =
1841
+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1842
+ funcProps->Node .DispatchGrid [1 ] =
1843
+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1844
+ funcProps->Node .DispatchGrid [2 ] =
1845
+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
1828
1846
}
1829
1847
if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1830
- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1831
- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1832
- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1848
+ funcProps->Node .MaxDispatchGrid [0 ] =
1849
+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1850
+ funcProps->Node .MaxDispatchGrid [1 ] =
1851
+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1852
+ funcProps->Node .MaxDispatchGrid [2 ] =
1853
+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
1833
1854
}
1834
1855
if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1835
- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1856
+ funcProps->Node .MaxRecursionDepth =
1857
+ GetIntConstAttrArg (astContext, pAttr->getCount (), 0 );
1836
1858
}
1837
1859
if (!FD->getAttr <HLSLNumThreadsAttr>()) {
1838
1860
// NumThreads wasn't specified.
@@ -2346,8 +2368,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2346
2368
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
2347
2369
2348
2370
if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2349
- node.MaxRecords =
2350
- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2371
+ node.MaxRecords = GetIntConstAttrArg (
2372
+ astContext,
2373
+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (), 1 );
2351
2374
}
2352
2375
if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
2353
2376
node.Flags .SetGloballyCoherent ();
@@ -2378,7 +2401,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2378
2401
// OutputID from attribute
2379
2402
if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
2380
2403
node.OutputID .Name = Attr->getName ().str ();
2381
- node.OutputID .Index = Attr->getArrayIndex ();
2404
+ node.OutputID .Index =
2405
+ GetIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
2382
2406
} else {
2383
2407
node.OutputID .Name = parmDecl->getName ().str ();
2384
2408
node.OutputID .Index = 0 ;
@@ -2437,7 +2461,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2437
2461
node.MaxRecordsSharedWith = ix;
2438
2462
}
2439
2463
if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2440
- node.MaxRecords = Attr->getMaxCount ();
2464
+ node.MaxRecords = GetIntConstAttrArg (astContext, Attr->getMaxCount (), 0 );
2441
2465
}
2442
2466
2443
2467
if (inputPatchCount > 1 ) {
0 commit comments