From 3e6a556cc9ec759fb2721cd86736400911e3239f Mon Sep 17 00:00:00 2001
From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com>
Date: Mon, 4 Aug 2025 12:35:26 -0700
Subject: [PATCH 1/3] UnaryMathOps and some minor cleanup
---
.../unittests/HLSLExec/LongVectorOpTable.xml | 196 +++++++++
.../unittests/HLSLExec/LongVectorTestData.h | 26 +-
.../clang/unittests/HLSLExec/LongVectors.cpp | 373 ++++++++++--------
tools/clang/unittests/HLSLExec/LongVectors.h | 169 ++++++--
4 files changed, 536 insertions(+), 228 deletions(-)
diff --git a/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml b/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml
index d8ed115f65..f1e84e6c5e 100644
--- a/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml
+++ b/tools/clang/unittests/HLSLExec/LongVectorOpTable.xml
@@ -779,4 +779,200 @@
SplitDoubleInputValueSet
+
+
+
+ String
+ String
+ String
+
+
+
+ UnaryMathOpType_Abs
+ int16
+
+
+ UnaryMathOpType_Sign
+ int16
+
+
+
+ UnaryMathOpType_Abs
+ int32
+
+
+ UnaryMathOpType_Sign
+ int32
+
+
+
+ UnaryMathOpType_Abs
+ int64
+
+
+ UnaryMathOpType_Sign
+ int64
+
+
+
+ UnaryMathOpType_Abs
+ uint16
+
+
+ UnaryMathOpType_Sign
+ uint16
+
+
+
+ UnaryMathOpType_Abs
+ uint32
+
+
+ UnaryMathOpType_Sign
+ uint32
+
+
+
+ UnaryMathOpType_Abs
+ uint64
+
+
+ UnaryMathOpType_Sign
+ uint64
+
+
+
+ UnaryMathOpType_Abs
+ float16
+
+
+ UnaryMathOpType_Ceil
+ float16
+
+
+ UnaryMathOpType_Exp
+ float16
+
+
+ UnaryMathOpType_Floor
+ float16
+
+
+ UnaryMathOpType_Frac
+ float16
+
+
+ UnaryMathOpType_Log
+ float16
+
+
+ UnaryMathOpType_Rcp
+ float16
+
+
+ UnaryMathOpType_Round
+ float16
+
+
+ UnaryMathOpType_Rsqrt
+ float16
+
+
+ UnaryMathOpType_Sign
+ float16
+
+
+ UnaryMathOpType_Sqrt
+ float16
+
+
+ UnaryMathOpType_Trunc
+ float16
+
+
+ UnaryMathOpType_Exp2
+ float16
+
+
+ UnaryMathOpType_Log10
+ float16
+
+
+ UnaryMathOpType_Log2
+ float16
+
+
+
+ UnaryMathOpType_Abs
+ float32
+
+
+ UnaryMathOpType_Ceil
+ float32
+
+
+ UnaryMathOpType_Exp
+ float32
+
+
+ UnaryMathOpType_Floor
+ float32
+
+
+ UnaryMathOpType_Frac
+ float32
+
+
+ UnaryMathOpType_Log
+ float32
+
+
+ UnaryMathOpType_Rcp
+ float32
+
+
+ UnaryMathOpType_Round
+ float32
+
+
+ UnaryMathOpType_Rsqrt
+ float32
+
+
+ UnaryMathOpType_Sign
+ float32
+
+
+ UnaryMathOpType_Sqrt
+ float32
+
+
+ UnaryMathOpType_Trunc
+ float32
+
+
+ UnaryMathOpType_Exp2
+ float32
+
+
+ UnaryMathOpType_Log10
+ float32
+
+
+ UnaryMathOpType_Log2
+ float32
+
+
+
+ UnaryMathOpType_Abs
+ float64
+
+
+ UnaryMathOpType_Sign
+ float64
+
+
diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h
index fdd6a00bed..733ce50fb4 100644
--- a/tools/clang/unittests/HLSLExec/LongVectorTestData.h
+++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h
@@ -7,6 +7,8 @@
#include
#include
+namespace LongVector {
+
// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes.
// Take int32_t as a constuctor argument and convert it to bool when needed.
// Comparisons cast to a bool because we only care if the bool representation is
@@ -192,11 +194,11 @@ struct HLSLHalf_t {
DirectX::PackedVector::HALF Val = 0;
};
-template struct LongVectorTestData {
+template struct TestData {
static const std::map> Data;
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1",
{false, true, false, false, false, false, true, true, true, true}},
@@ -205,49 +207,49 @@ template <> struct LongVectorTestData {
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {-6, 1, 7, 3, 8, 4, -3, 8, 8, -2}},
{L"DefaultInputValueSet2", {5, -6, -3, -2, 9, 3, 1, -3, -7, 2}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {-6, 1, 7, 3, 8, 4, -3, 8, 8, -2}},
{L"DefaultInputValueSet2", {5, -6, -3, -2, 9, 3, 1, -3, -7, 2}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {-6, 11, 7, 3, 8, 4, -3, 8, 8, -2}},
{L"DefaultInputValueSet2", {5, -1337, -3, -2, 9, 3, 1, -3, 501, 2}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {1, 699, 3, 1023, 5, 6, 0, 8, 9, 10}},
{L"DefaultInputValueSet2", {2, 111, 3, 4, 5, 9, 21, 8, 9, 10}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {1, 2, 3, 4, 5, 0, 7, 8, 9, 10}},
{L"DefaultInputValueSet2", {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1", {1, 2, 3, 4, 5, 0, 7, 1000, 9, 10}},
{L"DefaultInputValueSet2", {1, 2, 1337, 4, 5, 6, 7, 8, 9, 10}},
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1",
{-1.0, -1.0, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01}},
@@ -264,7 +266,7 @@ template <> struct LongVectorTestData {
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1",
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
@@ -280,7 +282,7 @@ template <> struct LongVectorTestData {
};
};
-template <> struct LongVectorTestData {
+template <> struct TestData {
inline static const std::map> Data = {
{L"DefaultInputValueSet1",
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
@@ -299,3 +301,5 @@ template <> struct LongVectorTestData {
};
#endif // LONGVECTORTESTDATA_H
+
+}; // namespace LongVector
\ No newline at end of file
diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp
index a3c2c94d00..50f0939645 100644
--- a/tools/clang/unittests/HLSLExec/LongVectors.cpp
+++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp
@@ -5,53 +5,27 @@
namespace LongVector {
template
-const OpTypeMetaData &
-getLongVectorOpType(const OpTypeMetaData (&Values)[Length],
- const std::wstring &OpTypeString) {
+const OpTypeMetaData &getOpType(const OpTypeMetaData (&Values)[Length],
+ const std::wstring &OpTypeString) {
for (size_t I = 0; I < Length; I++) {
if (Values[I].OpTypeString == OpTypeString)
return Values[I];
}
- LOG_ERROR_FMT_THROW(L"Invalid LongVectorOpType string: %s",
- OpTypeString.c_str());
+ LOG_ERROR_FMT_THROW(L"Invalid OpType string: %ls", OpTypeString.c_str());
// We need to return something to satisfy the compiler. We can't annotate
// LOG_ERROR_FMT_THROW with [[noreturn]] because the TAEF VERIFY_* macros that
// it uses are re-mapped on Unix to not throw exceptions, so they naturally
// return. If we hit this point it is a programmer error when implementing a
// test. Specifically, an entry for this OpTypeString is missing in the
- // static LongVectorOpTypeStringToOpMetaData array. Or something has been
+ // static OpTypeStringToOpMetaData array. Or something has been
// corrupted. Test execution is invalid at this point. Usin std::abort() keeps
// the compiler happy about no return path. And LOG_ERROR_FMT_THROW will still
// provide a useful error message via gtest logging on Unix systems.
std::abort();
}
-const OpTypeMetaData &
-getBinaryOpType(const std::wstring &OpTypeString) {
- return getLongVectorOpType(binaryOpTypeStringToOpMetaData,
- OpTypeString);
-}
-
-const OpTypeMetaData &
-getUnaryOpType(const std::wstring &OpTypeString) {
- return getLongVectorOpType(unaryOpTypeStringToOpMetaData,
- OpTypeString);
-}
-
-const OpTypeMetaData &
-getAsTypeOpType(const std::wstring &OpTypeString) {
- return getLongVectorOpType(asTypeOpTypeStringToOpMetaData,
- OpTypeString);
-}
-
-const OpTypeMetaData &
-getTrigonometricOpType(const std::wstring &OpTypeString) {
- return getLongVectorOpType(
- trigonometricOpTypeStringToOpMetaData, OpTypeString);
-}
-
// Helper to fill the test data from the shader buffer based on type. Convenient
// to be used when copying HLSL*_t types so we can use the underlying type.
template
@@ -185,8 +159,13 @@ void fillExpectedVector(VariantVector &ExpectedVector, size_t Count,
auto *TypedExpectedValues =
std::get_if>(&ExpectedVector);
- VERIFY_IS_NOT_NULL(TypedExpectedValues,
- L"Expected vector is not of the correct type.");
+ VERIFY_IS_NOT_NULL(
+ TypedExpectedValues,
+ L"Programmer error: Expected vector is not of the correct type.");
+
+ // A TestConfig may be reused for a different vector length. So this is a
+ // good time to make sure we clear the expected vector.
+ TypedExpectedValues->clear();
for (size_t Index = 0; Index < Count; ++Index)
TypedExpectedValues->push_back(ComputeFn(Index));
@@ -238,9 +217,7 @@ template std::string getHLSLTypeString() {
if (std::is_same_v)
return "uint64_t";
- std::string ErrStr("getHLSLTypeString() Unsupported type: ");
- ErrStr.append(typeid(DataTypeT).name());
- VERIFY_IS_TRUE(false, ErrStr.c_str());
+ LOG_ERROR_FMT_THROW(L"Unsupported type: %S", typeid(DataTypeT).name());
return "UnknownType";
}
@@ -369,10 +346,24 @@ TEST_F(OpTest, asTypeOpTest) {
dispatchTestByDataType(OpTypeMD, DataTypeIn, Handler);
}
-template
-void OpTest::dispatchTestByDataType(
- const OpTypeMetaData &OpTypeMd, std::wstring DataType,
- TableParameterHandler &Handler) {
+TEST_F(OpTest, unaryMathOpTest) {
+ WEX::TestExecution::SetVerifyOutput verifySettings(
+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
+
+ const int TableSize = sizeof(UnaryOpParameters) / sizeof(TableParameter);
+ TableParameterHandler Handler(UnaryOpParameters, TableSize);
+
+ std::wstring DataTypeIn(Handler.GetTableParamByName(L"DataType")->m_str);
+ std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str);
+
+ auto OpTypeMD = getUnaryMathOpType(OpTypeString);
+ dispatchTestByDataType(OpTypeMD, DataTypeIn, Handler);
+}
+
+template
+void OpTest::dispatchTestByDataType(const OpTypeMetaData &OpTypeMd,
+ std::wstring DataType,
+ TableParameterHandler &Handler) {
using namespace WEX::Common;
if (DataType == L"bool")
@@ -396,8 +387,42 @@ void OpTest::dispatchTestByDataType(
else if (DataType == L"float64")
dispatchTestByVectorLength(OpTypeMd, Handler);
else
- VERIFY_FAIL(
- String().Format(L"DataType: %ls is not recognized.", DataType.c_str()));
+ LOG_ERROR_FMT_THROW(L"Unrecognized DataType: %ls for OpType: %ls.",
+ DataType.c_str(), OpTypeMd.OpTypeString.c_str());
+}
+
+template <>
+void OpTest::dispatchTestByDataType(
+ const OpTypeMetaData &OpTypeMd, std::wstring DataType,
+ TableParameterHandler &Handler) {
+ using namespace WEX::Common;
+
+ // Unary math ops don't support HLSLBool_t. If we included a dispatcher for
+ // them by allowing the generic dispatchTestByDataType then we would get
+ // compile errors for a bunch of the templated std lib functions we call to
+ // compute unary math ops. This is easier and cleaner than guarding against in
+ // at that point.
+ if (DataType == L"int16")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"int32")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"int64")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"uint16")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"uint32")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"uint64")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"float16")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"float32")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else if (DataType == L"float64")
+ dispatchTestByVectorLength(OpTypeMd, Handler);
+ else
+ LOG_ERROR_FMT_THROW(L"Invalid UnaryMathOpType DataType: %ls.",
+ DataType.c_str());
}
template <>
@@ -419,10 +444,9 @@ void OpTest::dispatchTestByDataType(
DataType.c_str());
}
-template
-void OpTest::dispatchTestByVectorLength(
- const OpTypeMetaData &OpTypeMd,
- TableParameterHandler &Handler) {
+template
+void OpTest::dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMd,
+ TableParameterHandler &Handler) {
WEX::TestExecution::SetVerifyOutput verifySettings(
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
@@ -637,26 +661,6 @@ void fillShaderBufferFromLongVectorData(
return;
}
-template
-std::string TestConfig::getHLSLInputTypeString() const {
- return getHLSLTypeString();
-}
-
-template
-std::string TestConfig::getHLSLOutputTypeString() const {
-
- // Normal case, output matches input type ( DataTypeT )
- if (auto *Vec = std::get_if>(&ExpectedVector))
- return getHLSLTypeString();
-
- // Non normal cases should be handled in a derived TestConfig class. i.e
- // TestConfigAsType::getHLSLOutputTypeString()
- LOG_ERROR_FMT_THROW(
- L"getHLSLOutputTypeString() called with an unsupported op type: %ls",
- OpTypeName.c_str());
- return "UnknownType";
-}
-
// Returns the compiler options string to be used for the shader compilation.
// Reference ShaderOpArith.xml and the 'LongVectorOp' shader source to see how
// the defines are used in the shader code.
@@ -687,10 +691,10 @@ std::string TestConfig::getCompilerOptionsString() const {
CompilerOptions << " -DIS_BINARY_VECTOR_OP=1";
CompilerOptions << " -DFUNC=";
- CompilerOptions << (Intrinsic ? *Intrinsic : "");
+ CompilerOptions << (Intrinsic ? *Intrinsic : " ");
} else { // Unary Op
CompilerOptions << " -DFUNC=";
- CompilerOptions << (Intrinsic ? *Intrinsic : "");
+ CompilerOptions << (Intrinsic ? *Intrinsic : " ");
// Not used for unary ops, but needs to be a " " for compilation of the
// shader after macro expansion.
CompilerOptions << " -DOPERAND2= ";
@@ -719,68 +723,67 @@ TestConfig::getInputValueSet(size_t ValueSetIndex) const {
else if (ValueSetIndex == 2)
InputValueSetName = InputValueSetName2;
else
- VERIFY_FAIL("Invalid ValueSetIndex");
+ LOG_ERROR_FMT_THROW(L"Invalid ValueSetIndex: %zu. Expected 1 or 2.",
+ ValueSetIndex);
return getInputValueSetByKey(InputValueSetName);
}
-// Public version of verifyOutput. Handles logic to dispatch to the correct
-// templated verifyOutput based on the expected output type.
+template
+std::string TestConfig::getHLSLOutputTypeString() const {
+ // std::visit allows us to dispatch a call to getHLSLTypeString() with the
+ // the current underlying element type of ExpectedVector.
+ return std::visit(
+ [](const auto &Vec) {
+ using ElementType = typename std::decay_t::value_type;
+ return getHLSLTypeString();
+ },
+ ExpectedVector);
+}
+
template
bool TestConfig::verifyOutput(
const std::shared_ptr &TestResult) {
- // First try the most common case where the output datatype matches the input
- // datatype (DataTypeT). std::get_if will return a null pointer if the variant
- // isn't holding a std::vector
- if (auto TypedExpectedValues =
- std::get_if>(&ExpectedVector)) {
- return verifyOutput(TestResult);
- }
- // If we get here, its likely a programmer error. DataTypeT is the DataType
- // passed in to the TestConfig when its created. The only time the
- // ExpectedVector has a diffferent type is for a handful of ops, such as
- // casting, where the output type is different from the input type. But proper
- // dispatching to verifyOutput with the correct data type is intended to be
- // handled by derived classes. Hence, we throw an error here. See callers of
- // the private version of verifyOutput for examples of proper usage.
- LOG_ERROR_FMT_THROW(
- L"verifyOutput() called with an unsupported expected vector type: %ls.",
- typeid(ExpectedVector).name());
- return false;
+ // std::visit allows us to dispatch a call to the private version of
+ // verifyOutput using a std::vector that matches the type currently held in
+ // ExpectedVector. This works because ExpectedVector is a std::variant of
+ // vector types, and the lambda receives the active type at runtime. It's
+ // important that the TestConfig instance has correctly assigned the expected
+ // output type to ExpectedVector. By default, this is std::vector,
+ // but ops like AsTypeOpType must override it. For example,
+ // AsTypeOpType_AsFloat16 sets ExpectedVector to std::vector.
+ return std::visit(
+ [this, &TestResult](const auto &Vec) {
+ using ElementType = typename std::decay_t::value_type;
+ return this->verifyOutput(TestResult, Vec);
+ },
+ ExpectedVector);
}
-// Private version of verifyOutput. Expected to be called internally when we've
-// resolved what the expected output type is.
+// Private version of verifyOutput. Called by the public version of verifyOutput
+// which resolves OutputDataTypeT based on the ExpectedVector type.
template
template
bool TestConfig::verifyOutput(
- const std::shared_ptr &TestResult) {
+ const std::shared_ptr &TestResult,
+ const std::vector &ExpectedVector) {
- if (auto TypedExpectedValues =
- std::get_if>(&ExpectedVector)) {
- MappedData ShaderOutData;
- TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData);
+ WEX::Logging::Log::Comment(WEX::Common::String().Format(
+ L"verifyOutput with OpType: %ls ExpectedVector<%S>", OpTypeName.c_str(),
+ typeid(OutputDataTypeT).name()));
- // For most of the ops, the output vector size is the same as the input
- // vector size. But some, such as the AsUint_SplitDouble op, have an
- // output vector size that is double the input vector size.
- const size_t OutputVectorSize = (*TypedExpectedValues).size();
+ MappedData ShaderOutData;
+ TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData);
- std::vector ActualValues;
- fillLongVectorDataFromShaderBuffer(ShaderOutData, ActualValues,
- OutputVectorSize);
+ const size_t OutputVectorSize = ExpectedVector.size();
- return doVectorsMatch(ActualValues, *TypedExpectedValues, Tolerance,
- ValidationType);
- }
+ std::vector ActualValues;
+ fillLongVectorDataFromShaderBuffer(ShaderOutData, ActualValues,
+ OutputVectorSize);
- // This is the private TestConfig::VerifyOutput. If this is hitting its most
- // likely a new test case for a new op type that is misconfigured.
- LOG_ERROR_FMT_THROW(L"PRIVATE verifyOutput() called with an unsupported "
- L"expected vector type: %ls.",
- typeid(ExpectedVector).name());
- return false;
+ return doVectorsMatch(ActualValues, ExpectedVector, Tolerance,
+ ValidationType);
}
// Generic computeExpectedValues for Unary ops. Derived classes override
@@ -851,7 +854,7 @@ TestConfigAsType::TestConfigAsType(
BasicOpType = BasicOpType_Binary;
break;
default:
- VERIFY_FAIL("Invalid AsTypeOpType");
+ LOG_ERROR_FMT_THROW(L"Unsupported AsTypeOpType: %ls", OpTypeName.c_str());
}
}
@@ -930,66 +933,6 @@ void TestConfigAsType::computeExpectedValues(
});
}
-template
-std::string TestConfigAsType::getHLSLOutputTypeString() const {
-
- switch (OpType) {
- case AsTypeOpType_AsFloat16:
- return getHLSLTypeString();
- case AsTypeOpType_AsFloat:
- return getHLSLTypeString();
- case AsTypeOpType_AsInt:
- return getHLSLTypeString();
- case AsTypeOpType_AsInt16:
- return getHLSLTypeString();
- case AsTypeOpType_AsUint:
- return getHLSLTypeString();
- case AsTypeOpType_AsUint_SplitDouble:
- return getHLSLTypeString();
- case AsTypeOpType_AsUint16:
- return getHLSLTypeString();
- case AsTypeOpType_AsDouble:
- return getHLSLTypeString();
- default:
- LOG_ERROR_FMT_THROW(
- L"getHLSLOutputTypeString() called with an unsupported op type: %ls",
- OpTypeName.c_str());
- return "UnknownType";
- }
-}
-
-// Override verifyOutput for AsTypeOpType as the output type for these ops
-// doesn't match the input type of the config. Calls a private templated version
-// of verifyOutput with the correct data type based on the op.
-template
-bool TestConfigAsType::verifyOutput(
- const std::shared_ptr &TestResult) {
-
- switch (OpType) {
- case AsTypeOpType_AsFloat:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsFloat16:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsInt:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsInt16:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsUint:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsUint_SplitDouble:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsUint16:
- return TestConfig::verifyOutput(TestResult);
- case AsTypeOpType_AsDouble:
- return TestConfig::verifyOutput(TestResult);
- default:
- LOG_ERROR_FMT_THROW(
- L"verifyOutput() called with an unsupported AsTypeOpType: %ls",
- OpTypeName.c_str());
- return false;
- }
-}
-
template
TestConfigTrigonometric::TestConfigTrigonometric(
const OpTypeMetaData &OpTypeMd)
@@ -1056,7 +999,7 @@ TestConfigUnary::TestConfigUnary(
SpecialDefines = " -DFUNC_INITIALIZE=1";
break;
default:
- VERIFY_FAIL("Invalid UnaryOpType");
+ LOG_ERROR_FMT_THROW(L"Unsupported UnaryOpType: %ls", OpTypeName.c_str());
}
}
@@ -1104,7 +1047,7 @@ TestConfigBinary::TestConfigBinary(
BasicOpType = BasicOpType_Binary;
break;
default:
- VERIFY_FAIL("Invalid BinaryOpType");
+ LOG_ERROR_FMT_THROW(L"Invalid BinaryOpType: %ls", OpTypeName.c_str());
}
}
@@ -1149,4 +1092,84 @@ TestConfigBinary::computeExpectedValue(const DataTypeT &A,
}
}
-}; // namespace LongVector
\ No newline at end of file
+template
+TestConfigUnaryMath::TestConfigUnaryMath(
+ const OpTypeMetaData &OpTypeMd)
+ : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) {
+
+ BasicOpType = BasicOpType_Unary;
+
+ if (isFloatingPointType()) {
+ Tolerance = 1;
+ ValidationType = ValidationType_Ulp;
+ }
+
+ if (OpType == UnaryMathOpType_Sign)
+ ExpectedVector = std::vector{};
+}
+
+template
+DataTypeT
+TestConfigUnaryMath::computeExpectedValue(const DataTypeT &A) const {
+
+ // A bunch of the std match functions here are wrapped in () to avoid
+ // collisions with the macro defitions for various functions in windows.h
+ switch (OpType) {
+ case UnaryMathOpType_Abs:
+ return abs(A);
+ case UnaryMathOpType_Ceil:
+ return (std::ceil)(A);
+ case UnaryMathOpType_Floor:
+ return (std::floor)(A);
+ case UnaryMathOpType_Trunc:
+ return (std::trunc)(A);
+ case UnaryMathOpType_Round:
+ return (std::round)(A);
+ case UnaryMathOpType_Frac:
+ // std::frac is not a standard C++ function, but we can implement it as
+ return A - DataTypeT((std::floor)(A));
+ case UnaryMathOpType_Sqrt:
+ return (std::sqrt)(A);
+ case UnaryMathOpType_Rsqrt:
+ // std::rsqrt is not a standard C++ function, but we can implement it as
+ return DataTypeT(1.0) / DataTypeT((std::sqrt)(A));
+ case UnaryMathOpType_Exp:
+ return (std::exp)(A);
+ case UnaryMathOpType_Exp2:
+ return (std::exp2)(A);
+ case UnaryMathOpType_Log:
+ return (std::log)(A);
+ case UnaryMathOpType_Log2:
+ return (std::log2)(A);
+ case UnaryMathOpType_Log10:
+ return (std::log10)(A);
+ case UnaryMathOpType_Rcp:
+ // std::.rcp is not a standard C++ function, but we can implement it as
+ return DataTypeT(1.0) / A;
+ default:
+ LOG_ERROR_FMT_THROW(L"computeExpectedValue(const DataTypeT &A)"
+ L"called on an unrecognized unary math op: %ls",
+ OpTypeName.c_str());
+ return DataTypeT();
+ }
+}
+
+template
+void TestConfigUnaryMath::computeExpectedValues(
+ const std::vector &InputVector1) {
+
+ if (OpType == UnaryMathOpType_Sign) {
+ fillExpectedVector(
+ ExpectedVector, InputVector1.size(),
+ // The sign function returns an int32_t value, so we handle here instead
+ // of in the templated computeExpectedValue function.
+ [&](size_t Index) { return sign(InputVector1[Index]); });
+ return;
+ }
+
+ fillExpectedVector(
+ ExpectedVector, InputVector1.size(),
+ [&](size_t Index) { return computeExpectedValue(InputVector1[Index]); });
+}
+
+}; // namespace LongVector
diff --git a/tools/clang/unittests/HLSLExec/LongVectors.h b/tools/clang/unittests/HLSLExec/LongVectors.h
index b7a37d09c1..6ba6961a0f 100644
--- a/tools/clang/unittests/HLSLExec/LongVectors.h
+++ b/tools/clang/unittests/HLSLExec/LongVectors.h
@@ -45,14 +45,6 @@ using VariantVector =
std::vector, std::vector,
std::vector>;
-// A helper struct to clear a VariantVector using std::visit.
-// Example usage: std::visit(ClearVariantVector{}, MyVariantVector);
-struct ClearVariantVector {
- template void operator()(std::vector &vec) const {
- vec.clear();
- }
-};
-
template
void fillShaderBufferFromLongVectorData(std::vector &ShaderBuffer,
const std::vector &TestData);
@@ -90,17 +82,16 @@ template std::string getHLSLTypeString();
// expansion. May be empty. See getCompilerOptionsString() in LongVector.cpp and
// 'LongVectorOp' entry ShaderOpArith.xml. Expands to things like '+', '-',
// '*', etc.
-template struct OpTypeMetaData {
+template struct OpTypeMetaData {
std::wstring OpTypeString;
- LongVectorOpTypeT OpType;
+ OpTypeT OpType;
std::optional Intrinsic = std::nullopt;
std::optional Operator = std::nullopt;
};
template
-const OpTypeMetaData &
-getLongVectorOpType(const OpTypeMetaData (&Values)[Length],
- const std::wstring &OpTypeString);
+const OpTypeMetaData &getOpType(const OpTypeMetaData (&Values)[Length],
+ const std::wstring &OpTypeString);
enum ValidationType {
ValidationType_Epsilon,
@@ -159,7 +150,9 @@ static_assert(_countof(binaryOpTypeStringToOpMetaData) ==
"add a new enum value?");
const OpTypeMetaData &
-getBinaryOpType(const std::wstring &OpTypeString);
+getBinaryOpType(const std::wstring &OpTypeString) {
+ return getOpType(binaryOpTypeStringToOpMetaData, OpTypeString);
+}
enum UnaryOpType { UnaryOpType_Initialize, UnaryOpType_EnumValueCount };
@@ -173,7 +166,9 @@ static_assert(_countof(unaryOpTypeStringToOpMetaData) ==
"a new enum value?");
const OpTypeMetaData &
-getUnaryOpType(const std::wstring &OpTypeString);
+getUnaryOpType(const std::wstring &OpTypeString) {
+ return getOpType(unaryOpTypeStringToOpMetaData, OpTypeString);
+}
enum AsTypeOpType {
AsTypeOpType_AsFloat,
@@ -205,7 +200,9 @@ static_assert(_countof(asTypeOpTypeStringToOpMetaData) ==
"a new enum value?");
const OpTypeMetaData &
-getAsTypeOpType(const std::wstring &OpTypeString);
+getAsTypeOpType(const std::wstring &OpTypeString) {
+ return getOpType(asTypeOpTypeStringToOpMetaData, OpTypeString);
+}
enum TrigonometricOpType {
TrigonometricOpType_Acos,
@@ -240,7 +237,59 @@ static_assert(
"a new enum value?");
const OpTypeMetaData &
-getTrigonometricOpType(const std::wstring &OpTypeString);
+getTrigonometricOpType(const std::wstring &OpTypeString) {
+ return getOpType(trigonometricOpTypeStringToOpMetaData,
+ OpTypeString);
+}
+
+enum UnaryMathOpType {
+ UnaryMathOpType_Abs,
+ UnaryMathOpType_Sign,
+ UnaryMathOpType_Ceil,
+ UnaryMathOpType_Floor,
+ UnaryMathOpType_Trunc,
+ UnaryMathOpType_Round,
+ UnaryMathOpType_Frac,
+ UnaryMathOpType_Sqrt,
+ UnaryMathOpType_Rsqrt,
+ UnaryMathOpType_Exp,
+ UnaryMathOpType_Exp2,
+ UnaryMathOpType_Log,
+ UnaryMathOpType_Log2,
+ UnaryMathOpType_Log10,
+ UnaryMathOpType_Rcp,
+ UnaryMathOpType_EnumValueCount
+};
+
+static const OpTypeMetaData
+ unaryMathOpTypeStringToOpMetaData[] = {
+ {L"UnaryMathOpType_Abs", UnaryMathOpType_Abs, "abs"},
+ {L"UnaryMathOpType_Sign", UnaryMathOpType_Sign, "sign"},
+ {L"UnaryMathOpType_Ceil", UnaryMathOpType_Ceil, "ceil"},
+ {L"UnaryMathOpType_Floor", UnaryMathOpType_Floor, "floor"},
+ {L"UnaryMathOpType_Trunc", UnaryMathOpType_Trunc, "trunc"},
+ {L"UnaryMathOpType_Round", UnaryMathOpType_Round, "round"},
+ {L"UnaryMathOpType_Frac", UnaryMathOpType_Frac, "frac"},
+ {L"UnaryMathOpType_Sqrt", UnaryMathOpType_Sqrt, "sqrt"},
+ {L"UnaryMathOpType_Rsqrt", UnaryMathOpType_Rsqrt, "rsqrt"},
+ {L"UnaryMathOpType_Exp", UnaryMathOpType_Exp, "exp"},
+ {L"UnaryMathOpType_Exp2", UnaryMathOpType_Exp2, "exp2"},
+ {L"UnaryMathOpType_Log", UnaryMathOpType_Log, "log"},
+ {L"UnaryMathOpType_Log2", UnaryMathOpType_Log2, "log2"},
+ {L"UnaryMathOpType_Log10", UnaryMathOpType_Log10, "log10"},
+ {L"UnaryMathOpType_Rcp", UnaryMathOpType_Rcp, "rcp"},
+};
+
+static_assert(_countof(unaryMathOpTypeStringToOpMetaData) ==
+ UnaryMathOpType_EnumValueCount,
+ "unaryMathOpTypeStringToOpMetaData size mismatch. Did you add "
+ "a new enum value?");
+
+const OpTypeMetaData &
+getUnaryMathOpType(const std::wstring &OpTypeString) {
+ return getOpType(unaryMathOpTypeStringToOpMetaData,
+ OpTypeString);
+}
template
std::vector getInputValueSetByKey(const std::wstring &Key,
@@ -248,7 +297,7 @@ std::vector getInputValueSetByKey(const std::wstring &Key,
if (LogKey)
WEX::Logging::Log::Comment(
WEX::Common::String().Format(L"Using Value Set Key: %s", Key.c_str()));
- return std::vector(LongVectorTestData::Data.at(Key));
+ return std::vector(TestData::Data.at(Key));
}
// The TAEF test class.
@@ -280,8 +329,13 @@ class OpTest {
L"Table:LongVectorOpTable.xml#AsTypeOpTable")
END_TEST_METHOD()
- template
- void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD,
+ BEGIN_TEST_METHOD(unaryMathOpTest)
+ TEST_METHOD_PROPERTY(L"DataSource",
+ L"Table:LongVectorOpTable.xml#UnaryMathOpTable")
+ END_TEST_METHOD()
+
+ template
+ void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD,
std::wstring DataType,
TableParameterHandler &Handler);
@@ -290,10 +344,14 @@ class OpTest {
dispatchTestByDataType(const OpTypeMetaData &OpTypeMD,
std::wstring DataType, TableParameterHandler &Handler);
- template
- void
- dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMD,
- TableParameterHandler &Handler);
+ template <>
+ void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD,
+ std::wstring DataType,
+ TableParameterHandler &Handler);
+
+ template
+ void dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMD,
+ TableParameterHandler &Handler);
template
void testBaseMethod(std::unique_ptr> &TestConfig);
@@ -342,8 +400,10 @@ template class TestConfig {
bool isScalarOp() const { return BasicOpType == BasicOpType_ScalarBinary; }
// Helpers to get the hlsl type as a string for a given C++ type.
- std::string getHLSLInputTypeString() const;
- virtual std::string getHLSLOutputTypeString() const;
+ std::string getHLSLInputTypeString() const {
+ return getHLSLTypeString();
+ }
+ std::string getHLSLOutputTypeString() const;
virtual void
computeExpectedValues(const std::vector &InputVector1);
@@ -362,10 +422,6 @@ template class TestConfig {
}
void setLengthToTest(size_t LengthToTest) {
- // Make sure we clear the expected vector when setting a new length.
- // The TestConfig may be getting reused.
- std::visit(ClearVariantVector{}, ExpectedVector);
-
this->LengthToTest = LengthToTest;
}
@@ -386,12 +442,17 @@ template class TestConfig {
std::string getCompilerOptionsString() const;
- virtual bool
- verifyOutput(const std::shared_ptr &TestResult);
+ bool verifyOutput(const std::shared_ptr &TestResult);
private:
std::vector getInputValueSet(size_t ValueSetIndex) const;
+ // Templated version to be used when the output data type does not match the
+ // input data type.
+ template
+ bool verifyOutput(const std::shared_ptr &TestResult,
+ const std::vector &ExpectedVector);
+
// The input value sets are used to fill the shader buffer.
std::wstring InputValueSetName1 = L"DefaultInputValueSet1";
std::wstring InputValueSetName2 = L"DefaultInputValueSet2";
@@ -401,16 +462,11 @@ template class TestConfig {
protected:
// Prevent instances of TestConfig from being created directly. Want to force
// a derived class to be used for creation.
- template
- TestConfig(const OpTypeMetaData &OpTypeMd)
+ template
+ TestConfig(const OpTypeMetaData &OpTypeMd)
: OpTypeName(OpTypeMd.OpTypeString), Intrinsic(OpTypeMd.Intrinsic),
Operator(OpTypeMd.Operator) {}
- // Templated version to be used when the output data type does not match the
- // input data type.
- template
- bool verifyOutput(const std::shared_ptr &TestResult);
-
// The appropriate computeExpectedValue should be implemented in derived
// classes. Impelemented as virtual here to prevent requiring all derived
// classes from needing to implement. The OS builds disable RTTI, so using
@@ -457,9 +513,6 @@ class TestConfigAsType : public TestConfig {
void
computeExpectedValues(const std::vector &InputVector1,
const std::vector &InputVector2) override;
- std::string getHLSLOutputTypeString() const override;
- bool verifyOutput(
- const std::shared_ptr &TestResult) override;
private:
template
@@ -633,6 +686,32 @@ class TestConfigBinary : public TestConfig {
}
};
+template
+class TestConfigUnaryMath : public TestConfig {
+public:
+ TestConfigUnaryMath(const OpTypeMetaData &OpTypeMd);
+ DataTypeT computeExpectedValue(const DataTypeT &A) const override;
+ void
+ computeExpectedValues(const std::vector &InputVector1) override;
+
+private:
+ UnaryMathOpType OpType = UnaryMathOpType_EnumValueCount;
+
+ template int32_t sign(const DataTypeT &A) const {
+ // Return 1 for positive, -1 for negative, 0 for zero.
+ // Wrap comparison operands in DataTypeInT constructor to make sure
+ // we are comparing the same type.
+ return A > DataTypeT(0) ? 1 : A < DataTypeT(0) ? -1 : 0;
+ }
+
+ template DataTypeT abs(const DataTypeT &A) const {
+ if constexpr (std::is_unsigned::value)
+ return DataTypeT(A);
+ else
+ return (std::abs)(A);
+ }
+};
+
template
std::unique_ptr>
makeTestConfig(const OpTypeMetaData &OpTypeMetaData) {
@@ -657,6 +736,12 @@ makeTestConfig(const OpTypeMetaData &OpTypeMetaData) {
return std::make_unique>(OpTypeMetaData);
}
+template
+std::unique_ptr>
+makeTestConfig(const OpTypeMetaData &OpTypeMetaData) {
+ return std::make_unique>(OpTypeMetaData);
+}
+
}; // namespace LongVector
#endif // LONGVECTORS_H
From edf380ae4e3eb5e91209b5db28d780413c1d91a2 Mon Sep 17 00:00:00 2001
From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com>
Date: Mon, 11 Aug 2025 15:30:25 -0700
Subject: [PATCH 2/3] Fix location of endif
---
tools/clang/unittests/HLSLExec/LongVectorTestData.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h
index 733ce50fb4..60d677847d 100644
--- a/tools/clang/unittests/HLSLExec/LongVectorTestData.h
+++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h
@@ -300,6 +300,6 @@ template <> struct TestData {
};
};
-#endif // LONGVECTORTESTDATA_H
+}; // namespace LongVector
-}; // namespace LongVector
\ No newline at end of file
+#endif // LONGVECTORTESTDATA_H
From bbe8de08685e5f8c9cced61a45d0c276dd4e2af7 Mon Sep 17 00:00:00 2001
From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com>
Date: Mon, 11 Aug 2025 17:15:07 -0700
Subject: [PATCH 3/3] Switch cases with std::wstring
---
.../clang/unittests/HLSLExec/LongVectors.cpp | 65 +++++++++++++------
tools/clang/unittests/HLSLExec/LongVectors.h | 13 ++++
2 files changed, 57 insertions(+), 21 deletions(-)
diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp
index 50f0939645..da5a927eb9 100644
--- a/tools/clang/unittests/HLSLExec/LongVectors.cpp
+++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp
@@ -366,29 +366,41 @@ void OpTest::dispatchTestByDataType(const OpTypeMetaData &OpTypeMd,
TableParameterHandler &Handler) {
using namespace WEX::Common;
- if (DataType == L"bool")
+ switch (Hash_djb2a(DataType)) {
+ case Hash_djb2a(L"bool"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"int16")
+ break;
+ case Hash_djb2a(L"int16"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"int32")
+ break;
+ case Hash_djb2a(L"int32"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"int64")
+ break;
+ case Hash_djb2a(L"int64"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"uint16")
+ break;
+ case Hash_djb2a(L"uint16"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"uint32")
+ break;
+ case Hash_djb2a(L"uint32"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"uint64")
+ break;
+ case Hash_djb2a(L"uint64"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"float16")
+ break;
+ case Hash_djb2a(L"float16"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"float32")
+ break;
+ case Hash_djb2a(L"float32"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"float64")
+ break;
+ case Hash_djb2a(L"float64"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else
+ break;
+ default:
LOG_ERROR_FMT_THROW(L"Unrecognized DataType: %ls for OpType: %ls.",
DataType.c_str(), OpTypeMd.OpTypeString.c_str());
+ }
}
template <>
@@ -402,27 +414,38 @@ void OpTest::dispatchTestByDataType(
// compile errors for a bunch of the templated std lib functions we call to
// compute unary math ops. This is easier and cleaner than guarding against in
// at that point.
- if (DataType == L"int16")
+ switch (Hash_djb2a(DataType)) {
+ case Hash_djb2a(L"int16"):
dispatchTestByVectorLength(OpTypeMd, Handler);
- else if (DataType == L"int32")
+ break;
+ case Hash_djb2a(L"int32"):
dispatchTestByVectorLength