From 3e6a556cc9ec759fb2721cd86736400911e3239f Mon Sep 17 00:00:00 2001 From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com> Date: Mon, 4 Aug 2025 12:35:26 -0700 Subject: [PATCH 1/3] UnaryMathOps and some minor cleanup --- .../unittests/HLSLExec/LongVectorOpTable.xml | 196 +++++++++ .../unittests/HLSLExec/LongVectorTestData.h | 26 +- .../clang/unittests/HLSLExec/LongVectors.cpp | 373 ++++++++++-------- tools/clang/unittests/HLSLExec/LongVectors.h | 169 ++++++-- 4 files changed, 536 insertions(+), 228 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml b/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml index d8ed115f65..f1e84e6c5e 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml +++ b/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml @@ -779,4 +779,200 @@ SplitDoubleInputValueSet + + + + String + String + String + + + + UnaryMathOpType_Abs + int16 + + + UnaryMathOpType_Sign + int16 + + + + UnaryMathOpType_Abs + int32 + + + UnaryMathOpType_Sign + int32 + + + + UnaryMathOpType_Abs + int64 + + + UnaryMathOpType_Sign + int64 + + + + UnaryMathOpType_Abs + uint16 + + + UnaryMathOpType_Sign + uint16 + + + + UnaryMathOpType_Abs + uint32 + + + UnaryMathOpType_Sign + uint32 + + + + UnaryMathOpType_Abs + uint64 + + + UnaryMathOpType_Sign + uint64 + + + + UnaryMathOpType_Abs + float16 + + + UnaryMathOpType_Ceil + float16 + + + UnaryMathOpType_Exp + float16 + + + UnaryMathOpType_Floor + float16 + + + UnaryMathOpType_Frac + float16 + + + UnaryMathOpType_Log + float16 + + + UnaryMathOpType_Rcp + float16 + + + UnaryMathOpType_Round + float16 + + + UnaryMathOpType_Rsqrt + float16 + + + UnaryMathOpType_Sign + float16 + + + UnaryMathOpType_Sqrt + float16 + + + UnaryMathOpType_Trunc + float16 + + + UnaryMathOpType_Exp2 + float16 + + + UnaryMathOpType_Log10 + float16 + + + UnaryMathOpType_Log2 + float16 + + + + UnaryMathOpType_Abs + float32 + + + UnaryMathOpType_Ceil + float32 + + + UnaryMathOpType_Exp + float32 + + + UnaryMathOpType_Floor + float32 + + + UnaryMathOpType_Frac + float32 + + + UnaryMathOpType_Log + float32 + + + UnaryMathOpType_Rcp + float32 + + + UnaryMathOpType_Round + float32 + + + UnaryMathOpType_Rsqrt + float32 + + + UnaryMathOpType_Sign + float32 + + + UnaryMathOpType_Sqrt + float32 + + + UnaryMathOpType_Trunc + float32 + + + UnaryMathOpType_Exp2 + float32 + + + UnaryMathOpType_Log10 + float32 + + + UnaryMathOpType_Log2 + float32 + + + + UnaryMathOpType_Abs + float64 + + + UnaryMathOpType_Sign + float64 + +
diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h index fdd6a00bed..733ce50fb4 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorTestData.h +++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h @@ -7,6 +7,8 @@ #include #include +namespace LongVector { + // A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes. // Take int32_t as a constuctor argument and convert it to bool when needed. // Comparisons cast to a bool because we only care if the bool representation is @@ -192,11 +194,11 @@ struct HLSLHalf_t { DirectX::PackedVector::HALF Val = 0; }; -template struct LongVectorTestData { +template struct TestData { static const std::map> Data; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {false, true, false, false, false, false, true, true, true, true}}, @@ -205,49 +207,49 @@ template <> struct LongVectorTestData { }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {-6, 1, 7, 3, 8, 4, -3, 8, 8, -2}}, {L"DefaultInputValueSet2", {5, -6, -3, -2, 9, 3, 1, -3, -7, 2}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {-6, 1, 7, 3, 8, 4, -3, 8, 8, -2}}, {L"DefaultInputValueSet2", {5, -6, -3, -2, 9, 3, 1, -3, -7, 2}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {-6, 11, 7, 3, 8, 4, -3, 8, 8, -2}}, {L"DefaultInputValueSet2", {5, -1337, -3, -2, 9, 3, 1, -3, 501, 2}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {1, 699, 3, 1023, 5, 6, 0, 8, 9, 10}}, {L"DefaultInputValueSet2", {2, 111, 3, 4, 5, 9, 21, 8, 9, 10}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {1, 2, 3, 4, 5, 0, 7, 8, 9, 10}}, {L"DefaultInputValueSet2", {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {1, 2, 3, 4, 5, 0, 7, 1000, 9, 10}}, {L"DefaultInputValueSet2", {1, 2, 1337, 4, 5, 6, 7, 8, 9, 10}}, }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {-1.0, -1.0, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01}}, @@ -264,7 +266,7 @@ template <> struct LongVectorTestData { }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}}, @@ -280,7 +282,7 @@ template <> struct LongVectorTestData { }; }; -template <> struct LongVectorTestData { +template <> struct TestData { inline static const std::map> Data = { {L"DefaultInputValueSet1", {1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}}, @@ -299,3 +301,5 @@ template <> struct LongVectorTestData { }; #endif // LONGVECTORTESTDATA_H + +}; // namespace LongVector \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index a3c2c94d00..50f0939645 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -5,53 +5,27 @@ namespace LongVector { template -const OpTypeMetaData & -getLongVectorOpType(const OpTypeMetaData (&Values)[Length], - const std::wstring &OpTypeString) { +const OpTypeMetaData &getOpType(const OpTypeMetaData (&Values)[Length], + const std::wstring &OpTypeString) { for (size_t I = 0; I < Length; I++) { if (Values[I].OpTypeString == OpTypeString) return Values[I]; } - LOG_ERROR_FMT_THROW(L"Invalid LongVectorOpType string: %s", - OpTypeString.c_str()); + LOG_ERROR_FMT_THROW(L"Invalid OpType string: %ls", OpTypeString.c_str()); // We need to return something to satisfy the compiler. We can't annotate // LOG_ERROR_FMT_THROW with [[noreturn]] because the TAEF VERIFY_* macros that // it uses are re-mapped on Unix to not throw exceptions, so they naturally // return. If we hit this point it is a programmer error when implementing a // test. Specifically, an entry for this OpTypeString is missing in the - // static LongVectorOpTypeStringToOpMetaData array. Or something has been + // static OpTypeStringToOpMetaData array. Or something has been // corrupted. Test execution is invalid at this point. Usin std::abort() keeps // the compiler happy about no return path. And LOG_ERROR_FMT_THROW will still // provide a useful error message via gtest logging on Unix systems. std::abort(); } -const OpTypeMetaData & -getBinaryOpType(const std::wstring &OpTypeString) { - return getLongVectorOpType(binaryOpTypeStringToOpMetaData, - OpTypeString); -} - -const OpTypeMetaData & -getUnaryOpType(const std::wstring &OpTypeString) { - return getLongVectorOpType(unaryOpTypeStringToOpMetaData, - OpTypeString); -} - -const OpTypeMetaData & -getAsTypeOpType(const std::wstring &OpTypeString) { - return getLongVectorOpType(asTypeOpTypeStringToOpMetaData, - OpTypeString); -} - -const OpTypeMetaData & -getTrigonometricOpType(const std::wstring &OpTypeString) { - return getLongVectorOpType( - trigonometricOpTypeStringToOpMetaData, OpTypeString); -} - // Helper to fill the test data from the shader buffer based on type. Convenient // to be used when copying HLSL*_t types so we can use the underlying type. template @@ -185,8 +159,13 @@ void fillExpectedVector(VariantVector &ExpectedVector, size_t Count, auto *TypedExpectedValues = std::get_if>(&ExpectedVector); - VERIFY_IS_NOT_NULL(TypedExpectedValues, - L"Expected vector is not of the correct type."); + VERIFY_IS_NOT_NULL( + TypedExpectedValues, + L"Programmer error: Expected vector is not of the correct type."); + + // A TestConfig may be reused for a different vector length. So this is a + // good time to make sure we clear the expected vector. + TypedExpectedValues->clear(); for (size_t Index = 0; Index < Count; ++Index) TypedExpectedValues->push_back(ComputeFn(Index)); @@ -238,9 +217,7 @@ template std::string getHLSLTypeString() { if (std::is_same_v) return "uint64_t"; - std::string ErrStr("getHLSLTypeString() Unsupported type: "); - ErrStr.append(typeid(DataTypeT).name()); - VERIFY_IS_TRUE(false, ErrStr.c_str()); + LOG_ERROR_FMT_THROW(L"Unsupported type: %S", typeid(DataTypeT).name()); return "UnknownType"; } @@ -369,10 +346,24 @@ TEST_F(OpTest, asTypeOpTest) { dispatchTestByDataType(OpTypeMD, DataTypeIn, Handler); } -template -void OpTest::dispatchTestByDataType( - const OpTypeMetaData &OpTypeMd, std::wstring DataType, - TableParameterHandler &Handler) { +TEST_F(OpTest, unaryMathOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + const int TableSize = sizeof(UnaryOpParameters) / sizeof(TableParameter); + TableParameterHandler Handler(UnaryOpParameters, TableSize); + + std::wstring DataTypeIn(Handler.GetTableParamByName(L"DataType")->m_str); + std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); + + auto OpTypeMD = getUnaryMathOpType(OpTypeString); + dispatchTestByDataType(OpTypeMD, DataTypeIn, Handler); +} + +template +void OpTest::dispatchTestByDataType(const OpTypeMetaData &OpTypeMd, + std::wstring DataType, + TableParameterHandler &Handler) { using namespace WEX::Common; if (DataType == L"bool") @@ -396,8 +387,42 @@ void OpTest::dispatchTestByDataType( else if (DataType == L"float64") dispatchTestByVectorLength(OpTypeMd, Handler); else - VERIFY_FAIL( - String().Format(L"DataType: %ls is not recognized.", DataType.c_str())); + LOG_ERROR_FMT_THROW(L"Unrecognized DataType: %ls for OpType: %ls.", + DataType.c_str(), OpTypeMd.OpTypeString.c_str()); +} + +template <> +void OpTest::dispatchTestByDataType( + const OpTypeMetaData &OpTypeMd, std::wstring DataType, + TableParameterHandler &Handler) { + using namespace WEX::Common; + + // Unary math ops don't support HLSLBool_t. If we included a dispatcher for + // them by allowing the generic dispatchTestByDataType then we would get + // compile errors for a bunch of the templated std lib functions we call to + // compute unary math ops. This is easier and cleaner than guarding against in + // at that point. + if (DataType == L"int16") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"int32") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"int64") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"uint16") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"uint32") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"uint64") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"float16") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"float32") + dispatchTestByVectorLength(OpTypeMd, Handler); + else if (DataType == L"float64") + dispatchTestByVectorLength(OpTypeMd, Handler); + else + LOG_ERROR_FMT_THROW(L"Invalid UnaryMathOpType DataType: %ls.", + DataType.c_str()); } template <> @@ -419,10 +444,9 @@ void OpTest::dispatchTestByDataType( DataType.c_str()); } -template -void OpTest::dispatchTestByVectorLength( - const OpTypeMetaData &OpTypeMd, - TableParameterHandler &Handler) { +template +void OpTest::dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMd, + TableParameterHandler &Handler) { WEX::TestExecution::SetVerifyOutput verifySettings( WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); @@ -637,26 +661,6 @@ void fillShaderBufferFromLongVectorData( return; } -template -std::string TestConfig::getHLSLInputTypeString() const { - return getHLSLTypeString(); -} - -template -std::string TestConfig::getHLSLOutputTypeString() const { - - // Normal case, output matches input type ( DataTypeT ) - if (auto *Vec = std::get_if>(&ExpectedVector)) - return getHLSLTypeString(); - - // Non normal cases should be handled in a derived TestConfig class. i.e - // TestConfigAsType::getHLSLOutputTypeString() - LOG_ERROR_FMT_THROW( - L"getHLSLOutputTypeString() called with an unsupported op type: %ls", - OpTypeName.c_str()); - return "UnknownType"; -} - // Returns the compiler options string to be used for the shader compilation. // Reference ShaderOpArith.xml and the 'LongVectorOp' shader source to see how // the defines are used in the shader code. @@ -687,10 +691,10 @@ std::string TestConfig::getCompilerOptionsString() const { CompilerOptions << " -DIS_BINARY_VECTOR_OP=1"; CompilerOptions << " -DFUNC="; - CompilerOptions << (Intrinsic ? *Intrinsic : ""); + CompilerOptions << (Intrinsic ? *Intrinsic : " "); } else { // Unary Op CompilerOptions << " -DFUNC="; - CompilerOptions << (Intrinsic ? *Intrinsic : ""); + CompilerOptions << (Intrinsic ? *Intrinsic : " "); // Not used for unary ops, but needs to be a " " for compilation of the // shader after macro expansion. CompilerOptions << " -DOPERAND2= "; @@ -719,68 +723,67 @@ TestConfig::getInputValueSet(size_t ValueSetIndex) const { else if (ValueSetIndex == 2) InputValueSetName = InputValueSetName2; else - VERIFY_FAIL("Invalid ValueSetIndex"); + LOG_ERROR_FMT_THROW(L"Invalid ValueSetIndex: %zu. Expected 1 or 2.", + ValueSetIndex); return getInputValueSetByKey(InputValueSetName); } -// Public version of verifyOutput. Handles logic to dispatch to the correct -// templated verifyOutput based on the expected output type. +template +std::string TestConfig::getHLSLOutputTypeString() const { + // std::visit allows us to dispatch a call to getHLSLTypeString() with the + // the current underlying element type of ExpectedVector. + return std::visit( + [](const auto &Vec) { + using ElementType = typename std::decay_t::value_type; + return getHLSLTypeString(); + }, + ExpectedVector); +} + template bool TestConfig::verifyOutput( const std::shared_ptr &TestResult) { - // First try the most common case where the output datatype matches the input - // datatype (DataTypeT). std::get_if will return a null pointer if the variant - // isn't holding a std::vector - if (auto TypedExpectedValues = - std::get_if>(&ExpectedVector)) { - return verifyOutput(TestResult); - } - // If we get here, its likely a programmer error. DataTypeT is the DataType - // passed in to the TestConfig when its created. The only time the - // ExpectedVector has a diffferent type is for a handful of ops, such as - // casting, where the output type is different from the input type. But proper - // dispatching to verifyOutput with the correct data type is intended to be - // handled by derived classes. Hence, we throw an error here. See callers of - // the private version of verifyOutput for examples of proper usage. - LOG_ERROR_FMT_THROW( - L"verifyOutput() called with an unsupported expected vector type: %ls.", - typeid(ExpectedVector).name()); - return false; + // std::visit allows us to dispatch a call to the private version of + // verifyOutput using a std::vector that matches the type currently held in + // ExpectedVector. This works because ExpectedVector is a std::variant of + // vector types, and the lambda receives the active type at runtime. It's + // important that the TestConfig instance has correctly assigned the expected + // output type to ExpectedVector. By default, this is std::vector, + // but ops like AsTypeOpType must override it. For example, + // AsTypeOpType_AsFloat16 sets ExpectedVector to std::vector. + return std::visit( + [this, &TestResult](const auto &Vec) { + using ElementType = typename std::decay_t::value_type; + return this->verifyOutput(TestResult, Vec); + }, + ExpectedVector); } -// Private version of verifyOutput. Expected to be called internally when we've -// resolved what the expected output type is. +// Private version of verifyOutput. Called by the public version of verifyOutput +// which resolves OutputDataTypeT based on the ExpectedVector type. template template bool TestConfig::verifyOutput( - const std::shared_ptr &TestResult) { + const std::shared_ptr &TestResult, + const std::vector &ExpectedVector) { - if (auto TypedExpectedValues = - std::get_if>(&ExpectedVector)) { - MappedData ShaderOutData; - TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData); + WEX::Logging::Log::Comment(WEX::Common::String().Format( + L"verifyOutput with OpType: %ls ExpectedVector<%S>", OpTypeName.c_str(), + typeid(OutputDataTypeT).name())); - // For most of the ops, the output vector size is the same as the input - // vector size. But some, such as the AsUint_SplitDouble op, have an - // output vector size that is double the input vector size. - const size_t OutputVectorSize = (*TypedExpectedValues).size(); + MappedData ShaderOutData; + TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData); - std::vector ActualValues; - fillLongVectorDataFromShaderBuffer(ShaderOutData, ActualValues, - OutputVectorSize); + const size_t OutputVectorSize = ExpectedVector.size(); - return doVectorsMatch(ActualValues, *TypedExpectedValues, Tolerance, - ValidationType); - } + std::vector ActualValues; + fillLongVectorDataFromShaderBuffer(ShaderOutData, ActualValues, + OutputVectorSize); - // This is the private TestConfig::VerifyOutput. If this is hitting its most - // likely a new test case for a new op type that is misconfigured. - LOG_ERROR_FMT_THROW(L"PRIVATE verifyOutput() called with an unsupported " - L"expected vector type: %ls.", - typeid(ExpectedVector).name()); - return false; + return doVectorsMatch(ActualValues, ExpectedVector, Tolerance, + ValidationType); } // Generic computeExpectedValues for Unary ops. Derived classes override @@ -851,7 +854,7 @@ TestConfigAsType::TestConfigAsType( BasicOpType = BasicOpType_Binary; break; default: - VERIFY_FAIL("Invalid AsTypeOpType"); + LOG_ERROR_FMT_THROW(L"Unsupported AsTypeOpType: %ls", OpTypeName.c_str()); } } @@ -930,66 +933,6 @@ void TestConfigAsType::computeExpectedValues( }); } -template -std::string TestConfigAsType::getHLSLOutputTypeString() const { - - switch (OpType) { - case AsTypeOpType_AsFloat16: - return getHLSLTypeString(); - case AsTypeOpType_AsFloat: - return getHLSLTypeString(); - case AsTypeOpType_AsInt: - return getHLSLTypeString(); - case AsTypeOpType_AsInt16: - return getHLSLTypeString(); - case AsTypeOpType_AsUint: - return getHLSLTypeString(); - case AsTypeOpType_AsUint_SplitDouble: - return getHLSLTypeString(); - case AsTypeOpType_AsUint16: - return getHLSLTypeString(); - case AsTypeOpType_AsDouble: - return getHLSLTypeString(); - default: - LOG_ERROR_FMT_THROW( - L"getHLSLOutputTypeString() called with an unsupported op type: %ls", - OpTypeName.c_str()); - return "UnknownType"; - } -} - -// Override verifyOutput for AsTypeOpType as the output type for these ops -// doesn't match the input type of the config. Calls a private templated version -// of verifyOutput with the correct data type based on the op. -template -bool TestConfigAsType::verifyOutput( - const std::shared_ptr &TestResult) { - - switch (OpType) { - case AsTypeOpType_AsFloat: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsFloat16: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsInt: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsInt16: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsUint: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsUint_SplitDouble: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsUint16: - return TestConfig::verifyOutput(TestResult); - case AsTypeOpType_AsDouble: - return TestConfig::verifyOutput(TestResult); - default: - LOG_ERROR_FMT_THROW( - L"verifyOutput() called with an unsupported AsTypeOpType: %ls", - OpTypeName.c_str()); - return false; - } -} - template TestConfigTrigonometric::TestConfigTrigonometric( const OpTypeMetaData &OpTypeMd) @@ -1056,7 +999,7 @@ TestConfigUnary::TestConfigUnary( SpecialDefines = " -DFUNC_INITIALIZE=1"; break; default: - VERIFY_FAIL("Invalid UnaryOpType"); + LOG_ERROR_FMT_THROW(L"Unsupported UnaryOpType: %ls", OpTypeName.c_str()); } } @@ -1104,7 +1047,7 @@ TestConfigBinary::TestConfigBinary( BasicOpType = BasicOpType_Binary; break; default: - VERIFY_FAIL("Invalid BinaryOpType"); + LOG_ERROR_FMT_THROW(L"Invalid BinaryOpType: %ls", OpTypeName.c_str()); } } @@ -1149,4 +1092,84 @@ TestConfigBinary::computeExpectedValue(const DataTypeT &A, } } -}; // namespace LongVector \ No newline at end of file +template +TestConfigUnaryMath::TestConfigUnaryMath( + const OpTypeMetaData &OpTypeMd) + : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { + + BasicOpType = BasicOpType_Unary; + + if (isFloatingPointType()) { + Tolerance = 1; + ValidationType = ValidationType_Ulp; + } + + if (OpType == UnaryMathOpType_Sign) + ExpectedVector = std::vector{}; +} + +template +DataTypeT +TestConfigUnaryMath::computeExpectedValue(const DataTypeT &A) const { + + // A bunch of the std match functions here are wrapped in () to avoid + // collisions with the macro defitions for various functions in windows.h + switch (OpType) { + case UnaryMathOpType_Abs: + return abs(A); + case UnaryMathOpType_Ceil: + return (std::ceil)(A); + case UnaryMathOpType_Floor: + return (std::floor)(A); + case UnaryMathOpType_Trunc: + return (std::trunc)(A); + case UnaryMathOpType_Round: + return (std::round)(A); + case UnaryMathOpType_Frac: + // std::frac is not a standard C++ function, but we can implement it as + return A - DataTypeT((std::floor)(A)); + case UnaryMathOpType_Sqrt: + return (std::sqrt)(A); + case UnaryMathOpType_Rsqrt: + // std::rsqrt is not a standard C++ function, but we can implement it as + return DataTypeT(1.0) / DataTypeT((std::sqrt)(A)); + case UnaryMathOpType_Exp: + return (std::exp)(A); + case UnaryMathOpType_Exp2: + return (std::exp2)(A); + case UnaryMathOpType_Log: + return (std::log)(A); + case UnaryMathOpType_Log2: + return (std::log2)(A); + case UnaryMathOpType_Log10: + return (std::log10)(A); + case UnaryMathOpType_Rcp: + // std::.rcp is not a standard C++ function, but we can implement it as + return DataTypeT(1.0) / A; + default: + LOG_ERROR_FMT_THROW(L"computeExpectedValue(const DataTypeT &A)" + L"called on an unrecognized unary math op: %ls", + OpTypeName.c_str()); + return DataTypeT(); + } +} + +template +void TestConfigUnaryMath::computeExpectedValues( + const std::vector &InputVector1) { + + if (OpType == UnaryMathOpType_Sign) { + fillExpectedVector( + ExpectedVector, InputVector1.size(), + // The sign function returns an int32_t value, so we handle here instead + // of in the templated computeExpectedValue function. + [&](size_t Index) { return sign(InputVector1[Index]); }); + return; + } + + fillExpectedVector( + ExpectedVector, InputVector1.size(), + [&](size_t Index) { return computeExpectedValue(InputVector1[Index]); }); +} + +}; // namespace LongVector diff --git a/tools/clang/unittests/HLSLExec/LongVectors.h b/tools/clang/unittests/HLSLExec/LongVectors.h index b7a37d09c1..6ba6961a0f 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.h +++ b/tools/clang/unittests/HLSLExec/LongVectors.h @@ -45,14 +45,6 @@ using VariantVector = std::vector, std::vector, std::vector>; -// A helper struct to clear a VariantVector using std::visit. -// Example usage: std::visit(ClearVariantVector{}, MyVariantVector); -struct ClearVariantVector { - template void operator()(std::vector &vec) const { - vec.clear(); - } -}; - template void fillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, const std::vector &TestData); @@ -90,17 +82,16 @@ template std::string getHLSLTypeString(); // expansion. May be empty. See getCompilerOptionsString() in LongVector.cpp and // 'LongVectorOp' entry ShaderOpArith.xml. Expands to things like '+', '-', // '*', etc. -template struct OpTypeMetaData { +template struct OpTypeMetaData { std::wstring OpTypeString; - LongVectorOpTypeT OpType; + OpTypeT OpType; std::optional Intrinsic = std::nullopt; std::optional Operator = std::nullopt; }; template -const OpTypeMetaData & -getLongVectorOpType(const OpTypeMetaData (&Values)[Length], - const std::wstring &OpTypeString); +const OpTypeMetaData &getOpType(const OpTypeMetaData (&Values)[Length], + const std::wstring &OpTypeString); enum ValidationType { ValidationType_Epsilon, @@ -159,7 +150,9 @@ static_assert(_countof(binaryOpTypeStringToOpMetaData) == "add a new enum value?"); const OpTypeMetaData & -getBinaryOpType(const std::wstring &OpTypeString); +getBinaryOpType(const std::wstring &OpTypeString) { + return getOpType(binaryOpTypeStringToOpMetaData, OpTypeString); +} enum UnaryOpType { UnaryOpType_Initialize, UnaryOpType_EnumValueCount }; @@ -173,7 +166,9 @@ static_assert(_countof(unaryOpTypeStringToOpMetaData) == "a new enum value?"); const OpTypeMetaData & -getUnaryOpType(const std::wstring &OpTypeString); +getUnaryOpType(const std::wstring &OpTypeString) { + return getOpType(unaryOpTypeStringToOpMetaData, OpTypeString); +} enum AsTypeOpType { AsTypeOpType_AsFloat, @@ -205,7 +200,9 @@ static_assert(_countof(asTypeOpTypeStringToOpMetaData) == "a new enum value?"); const OpTypeMetaData & -getAsTypeOpType(const std::wstring &OpTypeString); +getAsTypeOpType(const std::wstring &OpTypeString) { + return getOpType(asTypeOpTypeStringToOpMetaData, OpTypeString); +} enum TrigonometricOpType { TrigonometricOpType_Acos, @@ -240,7 +237,59 @@ static_assert( "a new enum value?"); const OpTypeMetaData & -getTrigonometricOpType(const std::wstring &OpTypeString); +getTrigonometricOpType(const std::wstring &OpTypeString) { + return getOpType(trigonometricOpTypeStringToOpMetaData, + OpTypeString); +} + +enum UnaryMathOpType { + UnaryMathOpType_Abs, + UnaryMathOpType_Sign, + UnaryMathOpType_Ceil, + UnaryMathOpType_Floor, + UnaryMathOpType_Trunc, + UnaryMathOpType_Round, + UnaryMathOpType_Frac, + UnaryMathOpType_Sqrt, + UnaryMathOpType_Rsqrt, + UnaryMathOpType_Exp, + UnaryMathOpType_Exp2, + UnaryMathOpType_Log, + UnaryMathOpType_Log2, + UnaryMathOpType_Log10, + UnaryMathOpType_Rcp, + UnaryMathOpType_EnumValueCount +}; + +static const OpTypeMetaData + unaryMathOpTypeStringToOpMetaData[] = { + {L"UnaryMathOpType_Abs", UnaryMathOpType_Abs, "abs"}, + {L"UnaryMathOpType_Sign", UnaryMathOpType_Sign, "sign"}, + {L"UnaryMathOpType_Ceil", UnaryMathOpType_Ceil, "ceil"}, + {L"UnaryMathOpType_Floor", UnaryMathOpType_Floor, "floor"}, + {L"UnaryMathOpType_Trunc", UnaryMathOpType_Trunc, "trunc"}, + {L"UnaryMathOpType_Round", UnaryMathOpType_Round, "round"}, + {L"UnaryMathOpType_Frac", UnaryMathOpType_Frac, "frac"}, + {L"UnaryMathOpType_Sqrt", UnaryMathOpType_Sqrt, "sqrt"}, + {L"UnaryMathOpType_Rsqrt", UnaryMathOpType_Rsqrt, "rsqrt"}, + {L"UnaryMathOpType_Exp", UnaryMathOpType_Exp, "exp"}, + {L"UnaryMathOpType_Exp2", UnaryMathOpType_Exp2, "exp2"}, + {L"UnaryMathOpType_Log", UnaryMathOpType_Log, "log"}, + {L"UnaryMathOpType_Log2", UnaryMathOpType_Log2, "log2"}, + {L"UnaryMathOpType_Log10", UnaryMathOpType_Log10, "log10"}, + {L"UnaryMathOpType_Rcp", UnaryMathOpType_Rcp, "rcp"}, +}; + +static_assert(_countof(unaryMathOpTypeStringToOpMetaData) == + UnaryMathOpType_EnumValueCount, + "unaryMathOpTypeStringToOpMetaData size mismatch. Did you add " + "a new enum value?"); + +const OpTypeMetaData & +getUnaryMathOpType(const std::wstring &OpTypeString) { + return getOpType(unaryMathOpTypeStringToOpMetaData, + OpTypeString); +} template std::vector getInputValueSetByKey(const std::wstring &Key, @@ -248,7 +297,7 @@ std::vector getInputValueSetByKey(const std::wstring &Key, if (LogKey) WEX::Logging::Log::Comment( WEX::Common::String().Format(L"Using Value Set Key: %s", Key.c_str())); - return std::vector(LongVectorTestData::Data.at(Key)); + return std::vector(TestData::Data.at(Key)); } // The TAEF test class. @@ -280,8 +329,13 @@ class OpTest { L"Table:LongVectorOpTable.xml#AsTypeOpTable") END_TEST_METHOD() - template - void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD, + BEGIN_TEST_METHOD(unaryMathOpTest) + TEST_METHOD_PROPERTY(L"DataSource", + L"Table:LongVectorOpTable.xml#UnaryMathOpTable") + END_TEST_METHOD() + + template + void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD, std::wstring DataType, TableParameterHandler &Handler); @@ -290,10 +344,14 @@ class OpTest { dispatchTestByDataType(const OpTypeMetaData &OpTypeMD, std::wstring DataType, TableParameterHandler &Handler); - template - void - dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMD, - TableParameterHandler &Handler); + template <> + void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD, + std::wstring DataType, + TableParameterHandler &Handler); + + template + void dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMD, + TableParameterHandler &Handler); template void testBaseMethod(std::unique_ptr> &TestConfig); @@ -342,8 +400,10 @@ template class TestConfig { bool isScalarOp() const { return BasicOpType == BasicOpType_ScalarBinary; } // Helpers to get the hlsl type as a string for a given C++ type. - std::string getHLSLInputTypeString() const; - virtual std::string getHLSLOutputTypeString() const; + std::string getHLSLInputTypeString() const { + return getHLSLTypeString(); + } + std::string getHLSLOutputTypeString() const; virtual void computeExpectedValues(const std::vector &InputVector1); @@ -362,10 +422,6 @@ template class TestConfig { } void setLengthToTest(size_t LengthToTest) { - // Make sure we clear the expected vector when setting a new length. - // The TestConfig may be getting reused. - std::visit(ClearVariantVector{}, ExpectedVector); - this->LengthToTest = LengthToTest; } @@ -386,12 +442,17 @@ template class TestConfig { std::string getCompilerOptionsString() const; - virtual bool - verifyOutput(const std::shared_ptr &TestResult); + bool verifyOutput(const std::shared_ptr &TestResult); private: std::vector getInputValueSet(size_t ValueSetIndex) const; + // Templated version to be used when the output data type does not match the + // input data type. + template + bool verifyOutput(const std::shared_ptr &TestResult, + const std::vector &ExpectedVector); + // The input value sets are used to fill the shader buffer. std::wstring InputValueSetName1 = L"DefaultInputValueSet1"; std::wstring InputValueSetName2 = L"DefaultInputValueSet2"; @@ -401,16 +462,11 @@ template class TestConfig { protected: // Prevent instances of TestConfig from being created directly. Want to force // a derived class to be used for creation. - template - TestConfig(const OpTypeMetaData &OpTypeMd) + template + TestConfig(const OpTypeMetaData &OpTypeMd) : OpTypeName(OpTypeMd.OpTypeString), Intrinsic(OpTypeMd.Intrinsic), Operator(OpTypeMd.Operator) {} - // Templated version to be used when the output data type does not match the - // input data type. - template - bool verifyOutput(const std::shared_ptr &TestResult); - // The appropriate computeExpectedValue should be implemented in derived // classes. Impelemented as virtual here to prevent requiring all derived // classes from needing to implement. The OS builds disable RTTI, so using @@ -457,9 +513,6 @@ class TestConfigAsType : public TestConfig { void computeExpectedValues(const std::vector &InputVector1, const std::vector &InputVector2) override; - std::string getHLSLOutputTypeString() const override; - bool verifyOutput( - const std::shared_ptr &TestResult) override; private: template @@ -633,6 +686,32 @@ class TestConfigBinary : public TestConfig { } }; +template +class TestConfigUnaryMath : public TestConfig { +public: + TestConfigUnaryMath(const OpTypeMetaData &OpTypeMd); + DataTypeT computeExpectedValue(const DataTypeT &A) const override; + void + computeExpectedValues(const std::vector &InputVector1) override; + +private: + UnaryMathOpType OpType = UnaryMathOpType_EnumValueCount; + + template int32_t sign(const DataTypeT &A) const { + // Return 1 for positive, -1 for negative, 0 for zero. + // Wrap comparison operands in DataTypeInT constructor to make sure + // we are comparing the same type. + return A > DataTypeT(0) ? 1 : A < DataTypeT(0) ? -1 : 0; + } + + template DataTypeT abs(const DataTypeT &A) const { + if constexpr (std::is_unsigned::value) + return DataTypeT(A); + else + return (std::abs)(A); + } +}; + template std::unique_ptr> makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { @@ -657,6 +736,12 @@ makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { return std::make_unique>(OpTypeMetaData); } +template +std::unique_ptr> +makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { + return std::make_unique>(OpTypeMetaData); +} + }; // namespace LongVector #endif // LONGVECTORS_H From edf380ae4e3eb5e91209b5db28d780413c1d91a2 Mon Sep 17 00:00:00 2001 From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:30:25 -0700 Subject: [PATCH 2/3] Fix location of endif --- tools/clang/unittests/HLSLExec/LongVectorTestData.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h index 733ce50fb4..60d677847d 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorTestData.h +++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h @@ -300,6 +300,6 @@ template <> struct TestData { }; }; -#endif // LONGVECTORTESTDATA_H +}; // namespace LongVector -}; // namespace LongVector \ No newline at end of file +#endif // LONGVECTORTESTDATA_H From bbe8de08685e5f8c9cced61a45d0c276dd4e2af7 Mon Sep 17 00:00:00 2001 From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com> Date: Mon, 11 Aug 2025 17:15:07 -0700 Subject: [PATCH 3/3] Switch cases with std::wstring --- .../clang/unittests/HLSLExec/LongVectors.cpp | 65 +++++++++++++------ tools/clang/unittests/HLSLExec/LongVectors.h | 13 ++++ 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index 50f0939645..da5a927eb9 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -366,29 +366,41 @@ void OpTest::dispatchTestByDataType(const OpTypeMetaData &OpTypeMd, TableParameterHandler &Handler) { using namespace WEX::Common; - if (DataType == L"bool") + switch (Hash_djb2a(DataType)) { + case Hash_djb2a(L"bool"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"int16") + break; + case Hash_djb2a(L"int16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"int32") + break; + case Hash_djb2a(L"int32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"int64") + break; + case Hash_djb2a(L"int64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint16") + break; + case Hash_djb2a(L"uint16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint32") + break; + case Hash_djb2a(L"uint32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint64") + break; + case Hash_djb2a(L"uint64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float16") + break; + case Hash_djb2a(L"float16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float32") + break; + case Hash_djb2a(L"float32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float64") + break; + case Hash_djb2a(L"float64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else + break; + default: LOG_ERROR_FMT_THROW(L"Unrecognized DataType: %ls for OpType: %ls.", DataType.c_str(), OpTypeMd.OpTypeString.c_str()); + } } template <> @@ -402,27 +414,38 @@ void OpTest::dispatchTestByDataType( // compile errors for a bunch of the templated std lib functions we call to // compute unary math ops. This is easier and cleaner than guarding against in // at that point. - if (DataType == L"int16") + switch (Hash_djb2a(DataType)) { + case Hash_djb2a(L"int16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"int32") + break; + case Hash_djb2a(L"int32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"int64") + break; + case Hash_djb2a(L"int64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint16") + break; + case Hash_djb2a(L"uint16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint32") + break; + case Hash_djb2a(L"uint32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"uint64") + break; + case Hash_djb2a(L"uint64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float16") + break; + case Hash_djb2a(L"float16"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float32") + break; + case Hash_djb2a(L"float32"): dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float64") + break; + case Hash_djb2a(L"float64"): dispatchTestByVectorLength(OpTypeMd, Handler); - else + break; + default: LOG_ERROR_FMT_THROW(L"Invalid UnaryMathOpType DataType: %ls.", DataType.c_str()); + } } template <> diff --git a/tools/clang/unittests/HLSLExec/LongVectors.h b/tools/clang/unittests/HLSLExec/LongVectors.h index 6ba6961a0f..e9f433661f 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.h +++ b/tools/clang/unittests/HLSLExec/LongVectors.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -24,6 +25,18 @@ namespace LongVector { +// Used to compute the hash of a std::wstring at compile time. Gives us a way to +// create switch statements with a std::wstring. +// Note: Because this is evaluated at compile time the compiler detects hash +// collisions via an duplicate case statement error. +inline constexpr auto Hash_djb2a(const std::wstring_view String) { + unsigned long Hash{1337}; + for (wchar_t c : String) { + Hash = ((Hash << 5) + Hash) ^ static_cast(c); + } + return Hash; +} + // We don't have std::bit_cast in C++17, so we define our own version. template typename std::enable_if