Skip to content

Commit 1f0003f

Browse files
authored
Reapply "[HLSL] Rewrite semantics parsing" (#157718) (#158044)
This is a re-land of #152537 now that #157841 is merged.
1 parent b6674fe commit 1f0003f

22 files changed

+419
-128
lines changed

clang/include/clang/AST/Attr.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,40 @@ class HLSLAnnotationAttr : public InheritableAttr {
232232
}
233233
};
234234

235+
class HLSLSemanticAttr : public HLSLAnnotationAttr {
236+
unsigned SemanticIndex = 0;
237+
LLVM_PREFERRED_TYPE(bool)
238+
unsigned SemanticIndexable : 1;
239+
LLVM_PREFERRED_TYPE(bool)
240+
unsigned SemanticExplicitIndex : 1;
241+
242+
protected:
243+
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
244+
attr::Kind AK, bool IsLateParsed,
245+
bool InheritEvenIfAlreadyPresent, bool SemanticIndexable)
246+
: HLSLAnnotationAttr(Context, CommonInfo, AK, IsLateParsed,
247+
InheritEvenIfAlreadyPresent) {
248+
this->SemanticIndexable = SemanticIndexable;
249+
this->SemanticExplicitIndex = false;
250+
}
251+
252+
public:
253+
bool isSemanticIndexable() const { return SemanticIndexable; }
254+
255+
void setSemanticIndex(unsigned SemanticIndex) {
256+
this->SemanticIndex = SemanticIndex;
257+
this->SemanticExplicitIndex = true;
258+
}
259+
260+
unsigned getSemanticIndex() const { return SemanticIndex; }
261+
262+
// Implement isa/cast/dyncast/etc.
263+
static bool classof(const Attr *A) {
264+
return A->getKind() >= attr::FirstHLSLSemanticAttr &&
265+
A->getKind() <= attr::LastHLSLSemanticAttr;
266+
}
267+
};
268+
235269
/// A parameter attribute which changes the argument-passing ABI rule
236270
/// for the parameter.
237271
class ParameterABIAttr : public InheritableParamAttr {

clang/include/clang/Basic/Attr.td

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,16 @@ class DeclOrStmtAttr : InheritableAttr;
779779
/// An attribute class for HLSL Annotations.
780780
class HLSLAnnotationAttr : InheritableAttr;
781781

782+
class HLSLSemanticAttr<bit Indexable> : HLSLAnnotationAttr {
783+
bit SemanticIndexable = Indexable;
784+
int SemanticIndex = 0;
785+
bit SemanticExplicitIndex = 0;
786+
787+
let Spellings = [];
788+
let Subjects = SubjectList<[ParmVar, Field, Function]>;
789+
let LangOpts = [HLSL];
790+
}
791+
782792
/// A target-specific attribute. This class is meant to be used as a mixin
783793
/// with InheritableAttr or Attr depending on the attribute's needs.
784794
class TargetSpecificAttr<TargetSpec target> {
@@ -4890,27 +4900,6 @@ def HLSLNumThreads: InheritableAttr {
48904900
let Documentation = [NumThreadsDocs];
48914901
}
48924902

4893-
def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
4894-
let Spellings = [HLSLAnnotation<"sv_groupthreadid">];
4895-
let Subjects = SubjectList<[ParmVar, Field]>;
4896-
let LangOpts = [HLSL];
4897-
let Documentation = [HLSLSV_GroupThreadIDDocs];
4898-
}
4899-
4900-
def HLSLSV_GroupID: HLSLAnnotationAttr {
4901-
let Spellings = [HLSLAnnotation<"sv_groupid">];
4902-
let Subjects = SubjectList<[ParmVar, Field]>;
4903-
let LangOpts = [HLSL];
4904-
let Documentation = [HLSLSV_GroupIDDocs];
4905-
}
4906-
4907-
def HLSLSV_GroupIndex: HLSLAnnotationAttr {
4908-
let Spellings = [HLSLAnnotation<"sv_groupindex">];
4909-
let Subjects = SubjectList<[ParmVar, GlobalVar]>;
4910-
let LangOpts = [HLSL];
4911-
let Documentation = [HLSLSV_GroupIndexDocs];
4912-
}
4913-
49144903
def HLSLVkBinding : InheritableAttr {
49154904
let Spellings = [CXX11<"vk", "binding">];
49164905
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;
@@ -4969,13 +4958,35 @@ def HLSLResourceBinding: InheritableAttr {
49694958
}];
49704959
}
49714960

4972-
def HLSLSV_Position : HLSLAnnotationAttr {
4973-
let Spellings = [HLSLAnnotation<"sv_position">];
4974-
let Subjects = SubjectList<[ParmVar, Field]>;
4961+
def HLSLUnparsedSemantic : HLSLAnnotationAttr {
4962+
let Spellings = [];
4963+
let Args = [DefaultIntArgument<"Index", 0>,
4964+
DefaultBoolArgument<"ExplicitIndex", 0>];
4965+
let Subjects = SubjectList<[ParmVar, Field, Function]>;
49754966
let LangOpts = [HLSL];
4967+
let Documentation = [InternalOnly];
4968+
}
4969+
4970+
def HLSLSV_Position : HLSLSemanticAttr</* Indexable= */ 1> {
49764971
let Documentation = [HLSLSV_PositionDocs];
49774972
}
49784973

4974+
def HLSLSV_GroupThreadID : HLSLSemanticAttr</* Indexable= */ 0> {
4975+
let Documentation = [HLSLSV_GroupThreadIDDocs];
4976+
}
4977+
4978+
def HLSLSV_GroupID : HLSLSemanticAttr</* Indexable= */ 0> {
4979+
let Documentation = [HLSLSV_GroupIDDocs];
4980+
}
4981+
4982+
def HLSLSV_GroupIndex : HLSLSemanticAttr</* Indexable= */ 0> {
4983+
let Documentation = [HLSLSV_GroupIndexDocs];
4984+
}
4985+
4986+
def HLSLSV_DispatchThreadID : HLSLSemanticAttr</* Indexable= */ 0> {
4987+
let Documentation = [HLSLSV_DispatchThreadIDDocs];
4988+
}
4989+
49794990
def HLSLPackOffset: HLSLAnnotationAttr {
49804991
let Spellings = [HLSLAnnotation<"packoffset">];
49814992
let LangOpts = [HLSL];
@@ -4988,13 +4999,6 @@ def HLSLPackOffset: HLSLAnnotationAttr {
49884999
}];
49895000
}
49905001

4991-
def HLSLSV_DispatchThreadID: HLSLAnnotationAttr {
4992-
let Spellings = [HLSLAnnotation<"sv_dispatchthreadid">];
4993-
let Subjects = SubjectList<[ParmVar, Field]>;
4994-
let LangOpts = [HLSL];
4995-
let Documentation = [HLSLSV_DispatchThreadIDDocs];
4996-
}
4997-
49985002
def HLSLShader : InheritableAttr {
49995003
let Spellings = [Microsoft<"shader">];
50005004
let Subjects = SubjectList<[HLSLEntry]>;

clang/include/clang/Basic/DiagnosticFrontendKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,10 @@ def warn_hlsl_langstd_minimal :
400400
"recommend using %1 instead">,
401401
InGroup<HLSLDXCCompat>;
402402

403+
def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
404+
"for all input and outputs of an entry "
405+
"function or patch constant function">;
406+
403407
// ClangIR frontend errors
404408
def err_cir_to_cir_transform_failed : Error<
405409
"CIR-to-CIR transformation failed">, DefaultFatal;

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,9 +1860,8 @@ def note_max_tokens_total_override : Note<"total token limit set here">;
18601860

18611861
def err_expected_semantic_identifier : Error<
18621862
"expected HLSL Semantic identifier">;
1863-
def err_invalid_declaration_in_hlsl_buffer : Error<
1864-
"invalid declaration inside %select{tbuffer|cbuffer}0">;
1865-
def err_unknown_hlsl_semantic : Error<"unknown HLSL semantic %0">;
1863+
def err_invalid_declaration_in_hlsl_buffer
1864+
: Error<"invalid declaration inside %select{tbuffer|cbuffer}0">;
18661865
def err_hlsl_separate_attr_arg_and_number : Error<"wrong argument format for hlsl attribute, use %0 instead">;
18671866
def ext_hlsl_access_specifiers : ExtWarn<
18681867
"access specifiers are a clang HLSL extension">,

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13121,6 +13121,11 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
1312113121
def err_hlsl_missing_semantic_annotation : Error<
1312213122
"semantic annotations must be present for all parameters of an entry "
1312313123
"function or patch constant function">;
13124+
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
13125+
def err_hlsl_semantic_output_not_supported
13126+
: Error<"semantic %0 does not support output">;
13127+
def err_hlsl_semantic_indexing_not_supported
13128+
: Error<"semantic %0 does not allow indexing">;
1312413129
def err_hlsl_init_priority_unsupported : Error<
1312513130
"initializer priorities are not supported in HLSL">;
1312613131

clang/include/clang/Parse/Parser.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5188,6 +5188,14 @@ class Parser : public CodeCompletionHandler {
51885188
ParseHLSLAnnotations(Attrs, EndLoc);
51895189
}
51905190

5191+
struct ParsedSemantic {
5192+
StringRef Name = "";
5193+
unsigned Index = 0;
5194+
bool Explicit = false;
5195+
};
5196+
5197+
ParsedSemantic ParseHLSLSemantic();
5198+
51915199
void ParseHLSLAnnotations(ParsedAttributes &Attrs,
51925200
SourceLocation *EndLoc = nullptr,
51935201
bool CouldBeBitField = false);

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "clang/AST/Attr.h"
1818
#include "clang/AST/Type.h"
1919
#include "clang/AST/TypeLoc.h"
20+
#include "clang/Basic/DiagnosticSema.h"
2021
#include "clang/Basic/SourceLocation.h"
2122
#include "clang/Sema/SemaBase.h"
2223
#include "llvm/ADT/SmallVector.h"
@@ -129,6 +130,7 @@ class SemaHLSL : public SemaBase {
129130
bool ActOnUninitializedVarDecl(VarDecl *D);
130131
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
131132
void CheckEntryPoint(FunctionDecl *FD);
133+
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
132134
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
133135
const HLSLAnnotationAttr *AnnotationAttr);
134136
void DiagnoseAttrStageMismatch(
@@ -168,16 +170,31 @@ class SemaHLSL : public SemaBase {
168170
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
169171
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
170172
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL);
171-
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
172-
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
173-
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
174-
void handleSV_PositionAttr(Decl *D, const ParsedAttr &AL);
175173
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
176174
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
177175
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
178176
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL);
179177
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);
180178

179+
template <typename T>
180+
T *createSemanticAttr(const ParsedAttr &AL,
181+
std::optional<unsigned> Location) {
182+
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
183+
if (Attr->isSemanticIndexable())
184+
Attr->setSemanticIndex(Location ? *Location : 0);
185+
else if (Location.has_value()) {
186+
Diag(Attr->getLocation(), diag::err_hlsl_semantic_indexing_not_supported)
187+
<< Attr->getAttrName()->getName();
188+
return nullptr;
189+
}
190+
191+
return Attr;
192+
}
193+
194+
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
195+
std::optional<unsigned> Index);
196+
void handleSemanticAttr(Decl *D, const ParsedAttr &AL);
197+
181198
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL);
182199

183200
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);

clang/lib/Basic/Attributes.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,12 @@ AttributeCommonInfo::Kind
189189
AttributeCommonInfo::getParsedKind(const IdentifierInfo *Name,
190190
const IdentifierInfo *ScopeName,
191191
Syntax SyntaxUsed) {
192-
return ::getAttrKind(normalizeName(Name, ScopeName, SyntaxUsed), SyntaxUsed);
192+
AttributeCommonInfo::Kind Kind =
193+
::getAttrKind(normalizeName(Name, ScopeName, SyntaxUsed), SyntaxUsed);
194+
if (SyntaxUsed == AS_HLSLAnnotation &&
195+
Kind == AttributeCommonInfo::Kind::UnknownAttribute)
196+
return AttributeCommonInfo::Kind::AT_HLSLUnparsedSemantic;
197+
return Kind;
193198
}
194199

195200
AttributeCommonInfo::AttrArgsInfo

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "clang/AST/RecursiveASTVisitor.h"
2424
#include "clang/AST/Type.h"
2525
#include "clang/Basic/TargetOptions.h"
26+
#include "clang/Frontend/FrontendDiagnostic.h"
2627
#include "llvm/ADT/SmallString.h"
2728
#include "llvm/ADT/SmallVector.h"
2829
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
@@ -565,47 +566,78 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
565566
return B.CreateLoad(Ty, GV);
566567
}
567568

568-
llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
569-
const ParmVarDecl &D,
570-
llvm::Type *Ty) {
571-
assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
572-
if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
569+
llvm::Value *
570+
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
571+
const clang::DeclaratorDecl *Decl,
572+
SemanticInfo &ActiveSemantic) {
573+
if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
573574
llvm::Function *GroupIndex =
574575
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
575576
return B.CreateCall(FunctionCallee(GroupIndex));
576577
}
577-
if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
578+
579+
if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
578580
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
579581
llvm::Function *ThreadIDIntrinsic =
580582
llvm::Intrinsic::isOverloaded(IntrinID)
581583
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
582584
: CGM.getIntrinsic(IntrinID);
583-
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
585+
return buildVectorInput(B, ThreadIDIntrinsic, Type);
584586
}
585-
if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
587+
588+
if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
586589
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
587590
llvm::Function *GroupThreadIDIntrinsic =
588591
llvm::Intrinsic::isOverloaded(IntrinID)
589592
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
590593
: CGM.getIntrinsic(IntrinID);
591-
return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
594+
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
592595
}
593-
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
596+
597+
if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
594598
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
595599
llvm::Function *GroupIDIntrinsic =
596600
llvm::Intrinsic::isOverloaded(IntrinID)
597601
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
598602
: CGM.getIntrinsic(IntrinID);
599-
return buildVectorInput(B, GroupIDIntrinsic, Ty);
603+
return buildVectorInput(B, GroupIDIntrinsic, Type);
600604
}
601-
if (D.hasAttr<HLSLSV_PositionAttr>()) {
602-
if (getArch() == llvm::Triple::spirv)
603-
return createSPIRVBuiltinLoad(B, CGM.getModule(), Ty, "sv_position",
604-
/* BuiltIn::Position */ 0);
605-
llvm_unreachable("SV_Position semantic not implemented for this target.");
605+
606+
if (HLSLSV_PositionAttr *S =
607+
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
608+
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
609+
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
610+
S->getAttrName()->getName(),
611+
/* BuiltIn::FragCoord */ 15);
606612
}
607-
assert(false && "Unhandled parameter attribute");
608-
return nullptr;
613+
614+
llvm_unreachable("non-handled system semantic. FIXME.");
615+
}
616+
617+
llvm::Value *
618+
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
619+
const clang::DeclaratorDecl *Decl,
620+
SemanticInfo &ActiveSemantic) {
621+
622+
if (!ActiveSemantic.Semantic) {
623+
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
624+
if (!ActiveSemantic.Semantic) {
625+
CGM.getDiags().Report(Decl->getInnerLocStart(),
626+
diag::err_hlsl_semantic_missing);
627+
return nullptr;
628+
}
629+
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
630+
}
631+
632+
return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
633+
}
634+
635+
llvm::Value *
636+
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
637+
const clang::DeclaratorDecl *Decl,
638+
SemanticInfo &ActiveSemantic) {
639+
assert(!Type->isStructTy());
640+
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
609641
}
610642

611643
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
@@ -650,8 +682,10 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
650682
Args.emplace_back(PoisonValue::get(Param.getType()));
651683
continue;
652684
}
685+
653686
const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
654-
Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
687+
SemanticInfo ActiveSemantic = {nullptr, 0};
688+
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
655689
}
656690

657691
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);

0 commit comments

Comments
 (0)