diff --git a/sycl/source/detail/config.hpp b/sycl/source/detail/config.hpp index d08d42a238d99..2eb3716c76a05 100644 --- a/sycl/source/detail/config.hpp +++ b/sycl/source/detail/config.hpp @@ -199,35 +199,64 @@ template <> class SYCLConfig { private: public: static void GetSettings(size_t &MinFactor, size_t &GoodFactor, - size_t &MinRange) { - static const char *RoundParams = BaseT::getRawValue(); + size_t &MinRange, bool ForceUpdate = false) { + const char *RoundParams = BaseT::getRawValue(); if (RoundParams == nullptr) return; static bool ProcessedFactors = false; + static bool FactorsAreValid = false; static size_t MF; static size_t GF; static size_t MR; - if (!ProcessedFactors) { + if (!ProcessedFactors || ForceUpdate) { + auto GuardedStoi = [](size_t &val, const std::string &str) { + try { + int ParsedResult = std::stoi(str); + if (ParsedResult < 0) + return false; + val = ParsedResult; + return true; + // Ignore parsing exceptions, but throw on unexpected exceptions: + } catch (const std::invalid_argument &) { + } catch (const std::out_of_range &) { + } + return false; + }; + // Parse optional parameters of this form (all values required): // MinRound:PreferredRound:MinRange std::string Params(RoundParams); size_t Pos = Params.find(':'); - if (Pos != std::string::npos) { - MF = std::stoi(Params.substr(0, Pos)); + if (Pos != std::string::npos && GuardedStoi(MF, Params.substr(0, Pos)) && + MF > 0) { Params.erase(0, Pos + 1); Pos = Params.find(':'); - if (Pos != std::string::npos) { - GF = std::stoi(Params.substr(0, Pos)); + if (Pos != std::string::npos && + GuardedStoi(GF, Params.substr(0, Pos)) && GF > 0) { Params.erase(0, Pos + 1); - MR = std::stoi(Params); + // Factors are valid only if all parsed successfully: + FactorsAreValid = GuardedStoi(MR, Params); + // Note that MinRange = 0 is considered valid. } } - ProcessedFactors = true; + if (FactorsAreValid) { + ProcessedFactors = true; + } else { + std::cerr + << "WARNING: Invalid value passed for " + << "SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS (Expected format " + << "MinRound:PreferredRound:MinRange, where MinRound, " + "PreferredRound" + << " > 0, MinRange >= 0). Provided parameters will be ignored." + << std::endl; + } + } + if (FactorsAreValid) { + MinFactor = MF; + GoodFactor = GF; + MinRange = MR; } - MinFactor = MF; - GoodFactor = GF; - MinRange = MR; } }; diff --git a/sycl/unittests/config/ConfigTests.cpp b/sycl/unittests/config/ConfigTests.cpp index 0f990bc3c9847..2391bb608a61e 100644 --- a/sycl/unittests/config/ConfigTests.cpp +++ b/sycl/unittests/config/ConfigTests.cpp @@ -449,3 +449,69 @@ TEST(ConfigTests, CheckPersistentCacheEvictionThresholdTest) { OnDiskEvicType::reset(); TestConfig(0); } + +// SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS accepts ... +TEST(ConfigTests, CheckParallelForRangeRoundingParams) { + + // Lambda to set SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS. + auto SetRoundingParams = [](const char *value) { +#ifdef _WIN32 + _putenv_s("SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS", value); +#else + setenv("SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS", value, 1); +#endif + sycl::detail::readConfig(true); + }; + + // Lambda to assert test parameters are as expected. + auto AssertRoundingParams = [](size_t MF, size_t GF, size_t MR, + const char *errMsg, bool ForceUpdate = false) { + size_t ResultMF = 0, ResultGF = 0, ResultMR = 0; + SYCLConfig::GetSettings( + ResultMF, ResultGF, ResultMR, ForceUpdate); + EXPECT_EQ(MF, ResultMF) << errMsg; + EXPECT_EQ(GF, ResultGF) << errMsg; + EXPECT_EQ(MR, ResultMR) << errMsg; + }; + + // Lambda to test invalid input -- factors should remain unchanged. + auto TestBadInput = [&](const char *value, const char *errMsg) { + // Original factor values are stored as its own variable as size of size_t + // varies depending on system and architecture: + constexpr size_t MF = 1, GF = 2, MR = 3; + size_t TestMF = MF, TestGF = GF, TestMR = MR; + SetRoundingParams(value); + SYCLConfig::GetSettings( + TestMF, TestGF, TestMR, true); + EXPECT_EQ(TestMF, MF) << errMsg; + EXPECT_EQ(TestGF, GF) << errMsg; + EXPECT_EQ(TestMR, MR) << errMsg; + }; + + // Test malformed input: + constexpr char MalformedErr[] = + "Rounding parameters should be ignored on malformed input"; + TestBadInput("abc", MalformedErr); + TestBadInput("42", MalformedErr); + TestBadInput(":7", MalformedErr); + TestBadInput("7:", MalformedErr); + TestBadInput("1:2", MalformedErr); + TestBadInput("1:2:", MalformedErr); + TestBadInput("1:abc:3", MalformedErr); + + // Test well-formed input, but bad parameters: + constexpr char BadParamsErr[] = "Rounding parameters should be ignored if " + "parameters provided are invalid"; + TestBadInput("0:1:2", BadParamsErr); + TestBadInput("1:0:2", BadParamsErr); + TestBadInput("-1:2:3", BadParamsErr); + TestBadInput("1:2:31415926535897932384626433832795028841971", BadParamsErr); + + // Test valid values. + SetRoundingParams("8:16:32"); + AssertRoundingParams(8, 16, 32, + "Failed to read rounding parameters properly"); + SetRoundingParams("8:16:0"); + AssertRoundingParams(8, 16, 0, "0 is a valid value for MinRange", + /*ForceUpdate =*/true); +}