diff --git a/cpp/include/cudf/context.hpp b/cpp/include/cudf/context.hpp index 60062e8ab05..30cfb0f2485 100644 --- a/cpp/include/cudf/context.hpp +++ b/cpp/include/cudf/context.hpp @@ -25,6 +25,8 @@ namespace CUDF_EXPORT cudf { /// @brief Flags for controlling initialization steps enum class init_flags : std::uint32_t { + /// @brief No initialization steps + NONE = 0, /// @brief Load the nvCOMP library during initialization LOAD_NVCOMP = 1 << 0, /// @brief Initialize the JIT program cache during initialization @@ -43,6 +45,27 @@ constexpr init_flags operator|(init_flags lhs, init_flags rhs) noexcept return static_cast(static_cast(lhs) | static_cast(rhs)); } +/// @brief Bitwise AND operator for init_flags +/// @param lhs The left-hand side of the operator +/// @param rhs The right-hand side of the operator +/// @return The result of the bitwise AND operation +constexpr init_flags operator&(init_flags lhs, init_flags rhs) noexcept +{ + using underlying_t = std::underlying_type_t; + return static_cast(static_cast(lhs) & static_cast(rhs)); +} + +/// @brief Bitwise NOT operator for init_flags +/// @param flags The flags to negate +/// @return The result of the bitwise NOT operation, only flipping bits that are part of +/// init_flags::ALL +constexpr init_flags operator~(init_flags flags) noexcept +{ + using underlying_t = std::underlying_type_t; + return static_cast(static_cast(init_flags::ALL) & + ~static_cast(flags)); +} + /// @brief Check if a flag is set /// @param flags The flags to check against /// @param flag The specific flag to check for @@ -54,7 +77,8 @@ constexpr bool has_flag(init_flags flags, init_flags flag) noexcept /// @brief Initialize the cudf global context /// @param flags Optional flags to control which initialization steps to perform. -/// @throws std::runtime_error if the context is already initialized +/// Can be called multiple times to initialize additional components. If all selected +/// steps are already performed, the call has no effect. void initialize(init_flags flags = init_flags::INIT_JIT_CACHE); /// @brief de-initialize the cudf global context diff --git a/cpp/src/runtime/context.cpp b/cpp/src/runtime/context.cpp index 288ac652b8d..0621bfafa78 100644 --- a/cpp/src/runtime/context.cpp +++ b/cpp/src/runtime/context.cpp @@ -29,17 +29,13 @@ namespace cudf { context::context(init_flags flags) : _program_cache{nullptr} { - if (has_flag(flags, init_flags::INIT_JIT_CACHE)) { - _program_cache = std::make_unique(); - } - - if (has_flag(flags, init_flags::LOAD_NVCOMP)) { io::detail::nvcomp::load_nvcomp_library(); } - auto dump_codegen_flag = getenv_or("LIBCUDF_JIT_DUMP_CODEGEN", std::string{"OFF"}); _dump_codegen = (dump_codegen_flag == "ON" || dump_codegen_flag == "1"); auto use_jit_flag = getenv_or("LIBCUDF_JIT_ENABLED", std::string{"OFF"}); _use_jit = (use_jit_flag == "ON" || use_jit_flag == "1"); + + initialize_components(flags); } jit::program_cache& context::program_cache() @@ -50,6 +46,20 @@ jit::program_cache& context::program_cache() bool context::dump_codegen() const { return _dump_codegen; } +void context::initialize_components(init_flags flags) +{ + // Only initialize components that haven't been initialized yet + auto const new_flags = flags & ~_initialized_flags; + + if (has_flag(new_flags, init_flags::INIT_JIT_CACHE)) { + _program_cache = std::make_unique(); + } + + if (has_flag(new_flags, init_flags::LOAD_NVCOMP)) { io::detail::nvcomp::load_nvcomp_library(); } + + _initialized_flags = _initialized_flags | new_flags; +} + bool context::use_jit() const { return _use_jit; } std::unique_ptr& get_context_ptr_ref() @@ -71,15 +81,15 @@ namespace CUDF_EXPORT cudf { void initialize(init_flags flags) { - CUDF_EXPECTS( - get_context_ptr_ref() == nullptr, "context is already initialized", std::runtime_error); - get_context_ptr_ref() = std::make_unique(flags); + auto& ctx = get_context_ptr_ref(); + if (ctx == nullptr) { + // First initialization - create the context + ctx = std::make_unique(flags); + } else { + // Context already exists - initialize additional components + ctx->initialize_components(flags); + } } -void deinitialize() -{ - CUDF_EXPECTS( - get_context_ptr_ref() != nullptr, "context has already been deinitialized", std::runtime_error); - get_context_ptr_ref().reset(); -} +void deinitialize() { get_context_ptr_ref().reset(); } } // namespace CUDF_EXPORT cudf diff --git a/cpp/src/runtime/context.hpp b/cpp/src/runtime/context.hpp index 6923c714a0a..b79ea45d3d7 100644 --- a/cpp/src/runtime/context.hpp +++ b/cpp/src/runtime/context.hpp @@ -33,8 +33,9 @@ class program_cache; class context { private: std::unique_ptr _program_cache; - bool _dump_codegen = false; - bool _use_jit = false; + init_flags _initialized_flags = init_flags::NONE; + bool _dump_codegen = false; + bool _use_jit = false; public: context(init_flags flags = init_flags::INIT_JIT_CACHE); @@ -48,6 +49,10 @@ class context { [[nodiscard]] bool dump_codegen() const; + /// @brief Initialize additional components based on the provided flags + /// @param flags The initialization flags to process + void initialize_components(init_flags flags); + [[nodiscard]] bool use_jit() const; }; diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index def973e7b9f..3eb0d5bf5a2 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -427,6 +427,7 @@ ConfigureTest( utilities_tests/column_debug_tests.cpp utilities_tests/column_utilities_tests.cpp utilities_tests/column_wrapper_tests.cpp + utilities_tests/context_tests.cpp utilities_tests/default_stream_tests.cpp utilities_tests/io_utilities_tests.cpp utilities_tests/lists_column_wrapper_tests.cpp diff --git a/cpp/tests/utilities_tests/context_tests.cpp b/cpp/tests/utilities_tests/context_tests.cpp new file mode 100644 index 00000000000..73c2472d5a0 --- /dev/null +++ b/cpp/tests/utilities_tests/context_tests.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include + +struct ContextTest : public cudf::test::BaseFixture { + ~ContextTest() override + { + try { + cudf::deinitialize(); + } catch (...) { + } + } +}; + +TEST_F(ContextTest, MultipleInitializeCalls) +{ + cudf::initialize(cudf::init_flags::INIT_JIT_CACHE); + + EXPECT_NO_THROW(cudf::initialize(cudf::init_flags::LOAD_NVCOMP)); + EXPECT_NO_THROW(cudf::initialize(cudf::init_flags::ALL)); +} + +TEST_F(ContextTest, InitializeAfterDeinitialize) +{ + cudf::initialize(cudf::init_flags::ALL); + cudf::deinitialize(); + + EXPECT_NO_THROW(cudf::initialize(cudf::init_flags::INIT_JIT_CACHE)); +} + +TEST_F(ContextTest, DeinitializeWithoutInitialize) { EXPECT_NO_THROW(cudf::deinitialize()); } + +TEST_F(ContextTest, MultipleDeinitializeCalls) +{ + cudf::initialize(cudf::init_flags::ALL); + cudf::deinitialize(); + + EXPECT_NO_THROW(cudf::deinitialize()); +}