From d720bc675a6a0ba6e94302c597e2a4a7b7fbb98c Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 10 Jul 2025 12:06:24 +0100 Subject: [PATCH] Refactor reference counting in UR HIP adapter using new ur::RefCount class. --- .../source/adapters/hip/adapter.cpp | 10 +++++----- .../source/adapters/hip/adapter.hpp | 3 ++- .../source/adapters/hip/command_buffer.cpp | 8 ++++---- .../source/adapters/hip/command_buffer.hpp | 9 +++------ .../source/adapters/hip/context.cpp | 8 ++++---- .../source/adapters/hip/context.hpp | 11 +++------- .../source/adapters/hip/device.cpp | 2 +- .../source/adapters/hip/device.hpp | 10 +++++----- unified-runtime/source/adapters/hip/event.cpp | 8 ++++---- unified-runtime/source/adapters/hip/event.hpp | 8 ++------ .../source/adapters/hip/kernel.cpp | 10 +++++----- .../source/adapters/hip/kernel.hpp | 11 +++------- .../source/adapters/hip/memory.cpp | 8 ++++---- .../source/adapters/hip/memory.hpp | 20 ++++++++----------- .../source/adapters/hip/physical_mem.hpp | 11 +++------- .../source/adapters/hip/program.cpp | 10 +++++----- .../source/adapters/hip/program.hpp | 11 +++------- unified-runtime/source/adapters/hip/queue.cpp | 8 ++++---- .../source/adapters/hip/sampler.cpp | 8 ++++---- .../source/adapters/hip/sampler.hpp | 11 +++------- unified-runtime/source/adapters/hip/usm.cpp | 6 +++--- unified-runtime/source/adapters/hip/usm.hpp | 9 ++------- .../source/common/cuda-hip/stream_queue.hpp | 11 ++++++---- 23 files changed, 87 insertions(+), 124 deletions(-) diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index 225a743bc4455..826605f3a51c4 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -38,7 +38,7 @@ class ur_legacy_sink : public logger::Sink { // through UR entry points. // https://github.com/oneapi-src/unified-runtime/issues/1330 ur_adapter_handle_t_::ur_adapter_handle_t_() - : handle_base(), + : handle_base(), RefCount(0), logger(logger::get_logger("hip", /*default_log_level*/ UR_LOGGER_LEVEL_ERROR)) { Platform = std::make_unique(); @@ -58,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( std::call_once(InitFlag, [=]() { ur::hip::adapter = new ur_adapter_handle_t_; }); - ur::hip::adapter->RefCount++; + ur::hip::adapter->RefCount.retain(); *phAdapters = ur::hip::adapter; } if (pNumAdapters) { @@ -69,7 +69,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--ur::hip::adapter->RefCount == 0) { + if (ur::hip::adapter->RefCount.release()) { delete ur::hip::adapter; } @@ -77,7 +77,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - ur::hip::adapter->RefCount++; + ur::hip::adapter->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -99,7 +99,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_HIP); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(ur::hip::adapter->RefCount.load()); + return ReturnValue(ur::hip::adapter->RefCount.getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index ce054cfc27883..7033ca58cdcd2 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -11,6 +11,7 @@ #ifndef UR_HIP_ADAPTER_HPP_INCLUDED #define UR_HIP_ADAPTER_HPP_INCLUDED +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "platform.hpp" @@ -18,7 +19,7 @@ #include struct ur_adapter_handle_t_ : ur::hip::handle_base { - std::atomic RefCount = 0; + ur::RefCount RefCount; logger::Logger &logger; std::unique_ptr Platform; ur_adapter_handle_t_(); diff --git a/unified-runtime/source/adapters/hip/command_buffer.cpp b/unified-runtime/source/adapters/hip/command_buffer.cpp index 0514afdf71669..c8aac5b772c5d 100644 --- a/unified-runtime/source/adapters/hip/command_buffer.cpp +++ b/unified-runtime/source/adapters/hip/command_buffer.cpp @@ -26,7 +26,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_( bool IsInOrder) : handle_base(), Context(hContext), Device(hDevice), IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), HIPGraph{nullptr}, - HIPGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} { + HIPGraphExec{nullptr}, NextSyncPoint{0} { urContextRetain(hContext); } @@ -266,13 +266,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - hCommandBuffer->incrementReferenceCount(); + hCommandBuffer->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - if (hCommandBuffer->decrementReferenceCount() == 0) { + if (hCommandBuffer->RefCount.release()) { if (hCommandBuffer->CurrentExecution) { UR_CHECK_ERROR(hCommandBuffer->CurrentExecution->wait()); UR_CHECK_ERROR(urEventRelease(hCommandBuffer->CurrentExecution)); @@ -1055,7 +1055,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(hCommandBuffer->getReferenceCount()); + return ReturnValue(hCommandBuffer->RefCount.getCount()); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/hip/command_buffer.hpp b/unified-runtime/source/adapters/hip/command_buffer.hpp index 728d97719035b..bd6b085ac494e 100644 --- a/unified-runtime/source/adapters/hip/command_buffer.hpp +++ b/unified-runtime/source/adapters/hip/command_buffer.hpp @@ -8,6 +8,7 @@ // //===----------------------------------------------------------------------===// +#include "common/ur_ref_count.hpp" #include #include #include @@ -109,9 +110,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { registerSyncPoint(SyncPoint, std::move(HIPNode)); return SyncPoint; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - uint32_t getReferenceCount() const noexcept { return RefCount; } // UR context associated with this command-buffer ur_context_handle_t Context; @@ -125,9 +123,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { hipGraph_t HIPGraph; // HIP Graph Exec handle hipGraphExec_t HIPGraphExec = nullptr; - // Atomic variable counting the number of reference to this command_buffer - // using std::atomic prevents data race when incrementing/decrementing. - std::atomic_uint32_t RefCount; // Track the event of the current graph execution. This extra synchronization // is needed because HIP (unlike CUDA) does not seem to synchronize with other // executions of the same graph during hipGraphLaunch and hipExecGraphDestroy. @@ -142,4 +137,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { // Handles to individual commands in the command-buffer std::vector> CommandHandles; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/hip/context.cpp b/unified-runtime/source/adapters/hip/context.cpp index ae3c039a0abfe..0a4beb1c4c9ab 100644 --- a/unified-runtime/source/adapters/hip/context.cpp +++ b/unified-runtime/source/adapters/hip/context.cpp @@ -68,7 +68,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, return ReturnValue(hContext->getDevices().data(), hContext->getDevices().size()); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(hContext->getReferenceCount()); + return ReturnValue(hContext->RefCount.getCount()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. return ReturnValue(true); @@ -85,7 +85,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (hContext->decrementReferenceCount() == 0) { + if (hContext->RefCount.release()) { hContext->invokeExtendedDeleters(); delete hContext; } @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - assert(hContext->getReferenceCount() > 0); + assert(hContext->RefCount.getCount() > 0); - hContext->incrementReferenceCount(); + hContext->RefCount.retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/context.hpp b/unified-runtime/source/adapters/hip/context.hpp index 3c011cec43a1b..120d5346d497f 100644 --- a/unified-runtime/source/adapters/hip/context.hpp +++ b/unified-runtime/source/adapters/hip/context.hpp @@ -13,6 +13,7 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" #include "platform.hpp" @@ -88,10 +89,10 @@ struct ur_context_handle_t_ : ur::hip::handle_base { std::vector Devices; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} { + : handle_base(), Devices{Devs, Devs + NumDevices} { UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter)); }; @@ -125,12 +126,6 @@ struct ur_context_handle_t_ : ur::hip::handle_base { return std::distance(Devices.begin(), It); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - void addPool(ur_usm_pool_handle_t Pool); void removePool(ur_usm_pool_handle_t Pool); diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index 4675acdf0b178..f8751031d4e0c 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -478,7 +478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue("HIP"); } case UR_DEVICE_INFO_REFERENCE_COUNT: { - return ReturnValue(hDevice->getReferenceCount()); + return ReturnValue(hDevice->RefCount.getCount()); } case UR_DEVICE_INFO_VERSION: { std::stringstream S; diff --git a/unified-runtime/source/adapters/hip/device.hpp b/unified-runtime/source/adapters/hip/device.hpp index f03fcdd8463b0..57cc4484b58fe 100644 --- a/unified-runtime/source/adapters/hip/device.hpp +++ b/unified-runtime/source/adapters/hip/device.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include @@ -22,7 +23,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base { using native_type = hipDevice_t; native_type HIPDevice; - std::atomic_uint32_t RefCount; ur_platform_handle_t Platform; hipEvent_t EvBase; // HIP event used as base counter uint32_t DeviceIndex; @@ -38,8 +38,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base { public: ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase, ur_platform_handle_t Platform, uint32_t DeviceIndex) - : handle_base(), HIPDevice(HipDevice), RefCount{1}, Platform(Platform), - EvBase(EvBase), DeviceIndex(DeviceIndex) { + : handle_base(), HIPDevice(HipDevice), Platform(Platform), EvBase(EvBase), + DeviceIndex(DeviceIndex) { UR_CHECK_ERROR(hipDeviceGetAttribute( &MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice)); @@ -99,8 +99,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base { native_type get() const noexcept { return HIPDevice; }; - uint32_t getReferenceCount() const noexcept { return RefCount; } - ur_platform_handle_t getPlatform() const noexcept { return Platform; }; uint64_t getElapsedTime(hipEvent_t) const; @@ -126,6 +124,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base { }; bool supportsHardwareImages() const noexcept { return HardwareImageSupport; } + + ur::RefCount RefCount; }; int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute); diff --git a/unified-runtime/source/adapters/hip/event.cpp b/unified-runtime/source/adapters/hip/event.cpp index f7da65d6fc993..b6a4070b8eca7 100644 --- a/unified-runtime/source/adapters/hip/event.cpp +++ b/unified-runtime/source/adapters/hip/event.cpp @@ -189,7 +189,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->RefCount.getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: { try { return ReturnValue(hEvent->getExecutionStatus()); @@ -245,7 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventSetCallback(ur_event_handle_t, } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - const auto RefCount = hEvent->incrementReferenceCount(); + const auto RefCount = hEvent->RefCount.retain(); if (RefCount == 0) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; @@ -257,12 +257,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hEvent->getReferenceCount() == 0) { + if (hEvent->RefCount.getCount() == 0) { return UR_RESULT_ERROR_INVALID_EVENT; } // decrement ref count. If it is 0, delete the event. - if (hEvent->decrementReferenceCount() == 0) { + if (hEvent->RefCount.release()) { std::unique_ptr event_ptr{hEvent}; ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { diff --git a/unified-runtime/source/adapters/hip/event.hpp b/unified-runtime/source/adapters/hip/event.hpp index 63fe5ca273449..14df84ae85140 100644 --- a/unified-runtime/source/adapters/hip/event.hpp +++ b/unified-runtime/source/adapters/hip/event.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "queue.hpp" /// UR Event mapping to hipEvent_t @@ -57,16 +58,11 @@ struct ur_event_handle_t_ : ur::hip::handle_base { ur_context_handle_t getContext() const noexcept { return Context; }; uint32_t getEventId() const noexcept { return EventId; } - // Reference counting. - uint32_t getReferenceCount() const noexcept { return RefCount; } - uint32_t incrementReferenceCount() { return ++RefCount; } - uint32_t decrementReferenceCount() { return --RefCount; } + ur::RefCount RefCount; private: ur_command_t CommandType; // The type of command associated with event. - std::atomic_uint32_t RefCount{1}; // Event reference count. - bool HasOwnership{true}; // Signifies if event owns the native type. bool HasProfiling{false}; // Signifies if event has profiling information. diff --git a/unified-runtime/source/adapters/hip/kernel.cpp b/unified-runtime/source/adapters/hip/kernel.cpp index 39cddecd1efd5..980449020d125 100644 --- a/unified-runtime/source/adapters/hip/kernel.cpp +++ b/unified-runtime/source/adapters/hip/kernel.cpp @@ -127,9 +127,9 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->RefCount.getCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); - hKernel->incrementReferenceCount(); + hKernel->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -137,10 +137,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); // decrement ref count. If it is 0, delete the program. - if (hKernel->decrementReferenceCount() == 0) { + if (hKernel->RefCount.release()) { // no internal cuda resources to clean up. Just delete it. delete hKernel; return UR_RESULT_SUCCESS; @@ -201,7 +201,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_NUM_ARGS: return ReturnValue(hKernel->getNumArgs()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->getReferenceCount()); + return ReturnValue(hKernel->RefCount.getCount()); case UR_KERNEL_INFO_CONTEXT: return ReturnValue(hKernel->getContext()); case UR_KERNEL_INFO_PROGRAM: diff --git a/unified-runtime/source/adapters/hip/kernel.hpp b/unified-runtime/source/adapters/hip/kernel.hpp index 569a2dc8c0bf4..792b2df012455 100644 --- a/unified-runtime/source/adapters/hip/kernel.hpp +++ b/unified-runtime/source/adapters/hip/kernel.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include "common/ur_ref_count.hpp" #include #include @@ -41,7 +42,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { std::string Name; ur_context_handle_t Context; ur_program_handle_t Program; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u; size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions]; @@ -238,7 +239,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { ur_context_handle_t Ctxt) : handle_base(), Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam}, Name{Name}, Context{Ctxt}, - Program{Program}, RefCount{1} { + Program{Program} { urProgramRetain(Program); urContextRetain(Context); @@ -267,12 +268,6 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { ur_program_handle_t getProgram() const noexcept { return Program; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - native_type get() const noexcept { return Function; }; native_type getWithOffsetParameter() const noexcept { diff --git a/unified-runtime/source/adapters/hip/memory.cpp b/unified-runtime/source/adapters/hip/memory.cpp index d1e8789a41927..fdc765722eeb4 100644 --- a/unified-runtime/source/adapters/hip/memory.cpp +++ b/unified-runtime/source/adapters/hip/memory.cpp @@ -63,7 +63,7 @@ checkSupportedImageChannelType(ur_image_channel_type_t ImageChannelType) { UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { try { // Do nothing if there are other references - if (hMem->decrementReferenceCount() > 0) { + if (!hMem->RefCount.release()) { return UR_RESULT_SUCCESS; } @@ -259,7 +259,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, return ReturnValue(hMemory->getContext()); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hMemory->getReferenceCount()); + return ReturnValue(hMemory->RefCount.getCount()); } default: @@ -439,8 +439,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory, } UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { - UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); - hMem->incrementReferenceCount(); + UR_ASSERT(hMem->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); + hMem->RefCount.retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/memory.hpp b/unified-runtime/source/adapters/hip/memory.hpp index b2367edf8f4b7..239e3949c740c 100644 --- a/unified-runtime/source/adapters/hip/memory.hpp +++ b/unified-runtime/source/adapters/hip/memory.hpp @@ -10,8 +10,10 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" + #include #include #include @@ -317,7 +319,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { ur_context Context; /// Reference counting of the handler - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; // Original mem flags passed ur_mem_flags_t MemFlags; @@ -347,7 +349,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { /// Constructs the UR mem handler for a non-typed allocation ("buffer") ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode, void *HostPtr, size_t Size) - : Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), Mem{std::in_place_type, Ctxt, this, Mode, HostPtr, Size} { urContextRetain(Context); @@ -355,9 +357,9 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { // Subbuffer constructor ur_mem_handle_t_(ur_mem Parent, size_t SubBufferOffset) - : handle_base(), Context{Parent->Context}, RefCount{1}, - MemFlags{Parent->MemFlags}, HaveMigratedToDeviceSinceLastWrite( - Parent->Context->Devices.size(), false), + : handle_base(), Context{Parent->Context}, MemFlags{Parent->MemFlags}, + HaveMigratedToDeviceSinceLastWrite(Parent->Context->Devices.size(), + false), Mem{BufferMem{std::get(Parent->Mem)}} { auto &SubBuffer = std::get(Mem); SubBuffer.Parent = Parent; @@ -378,7 +380,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { ur_mem_handle_t_(ur_context Ctxt, ur_mem_flags_t MemFlags, ur_image_format_t ImageFormat, ur_image_desc_t ImageDesc, void *HostPtr) - : Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), Mem{std::in_place_type, Ctxt, @@ -419,12 +421,6 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { ur_context getContext() const noexcept { return Context; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - void setLastQueueWritingToMemObj(ur_queue_handle_t WritingQueue) { if (LastQueueWritingToMemObj != nullptr) { urQueueRelease(LastQueueWritingToMemObj); diff --git a/unified-runtime/source/adapters/hip/physical_mem.hpp b/unified-runtime/source/adapters/hip/physical_mem.hpp index 47342ae206510..5c0888525dc7d 100644 --- a/unified-runtime/source/adapters/hip/physical_mem.hpp +++ b/unified-runtime/source/adapters/hip/physical_mem.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" #include "platform.hpp" @@ -18,13 +19,7 @@ /// TODO: Implement. /// struct ur_physical_mem_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; - ur_physical_mem_handle_t_() : handle_base(), RefCount(1) {} - - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + ur_physical_mem_handle_t_() : handle_base() {} }; diff --git a/unified-runtime/source/adapters/hip/program.cpp b/unified-runtime/source/adapters/hip/program.cpp index 94e2f46440e96..94451730c66ca 100644 --- a/unified-runtime/source/adapters/hip/program.cpp +++ b/unified-runtime/source/adapters/hip/program.cpp @@ -385,7 +385,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram->getReferenceCount()); + return ReturnValue(hProgram->RefCount.getCount()); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(hProgram->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -418,8 +418,8 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - UR_ASSERT(hProgram->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); - hProgram->incrementReferenceCount(); + UR_ASSERT(hProgram->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); + hProgram->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -430,11 +430,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hProgram->getReferenceCount() != 0, + UR_ASSERT(hProgram->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_PROGRAM); // decrement ref count. If it is 0, delete the program. - if (hProgram->decrementReferenceCount() == 0) { + if (hProgram->RefCount.release()) { std::unique_ptr ProgramPtr{hProgram}; try { ScopedDevice Active(hProgram->getDevice()); diff --git a/unified-runtime/source/adapters/hip/program.hpp b/unified-runtime/source/adapters/hip/program.hpp index c94818c6c43ab..b1b1040dcd6ec 100644 --- a/unified-runtime/source/adapters/hip/program.hpp +++ b/unified-runtime/source/adapters/hip/program.hpp @@ -14,6 +14,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "context.hpp" /// Implementation of UR Program on HIP Module object @@ -22,7 +23,7 @@ struct ur_program_handle_t_ : ur::hip::handle_base { native_type Module; const char *Binary; size_t BinarySizeInBytes; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; ur_context_handle_t Context; ur_device_handle_t Device; std::string ExecutableCache; @@ -49,7 +50,7 @@ struct ur_program_handle_t_ : ur::hip::handle_base { ur_program_handle_t_(ur_context_handle_t Ctxt, ur_device_handle_t Device) : handle_base(), Module{nullptr}, Binary{}, BinarySizeInBytes{0}, - RefCount{1}, Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{}, + Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{}, KernelReqdSubGroupSizeMD{} { urContextRetain(Context); @@ -71,12 +72,6 @@ struct ur_program_handle_t_ : ur::hip::handle_base { native_type get() const noexcept { return Module; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - ur_result_t getGlobalVariablePointer(const char *name, hipDeviceptr_t *DeviceGlobal, size_t *DeviceGlobalSize); diff --git a/unified-runtime/source/adapters/hip/queue.cpp b/unified-runtime/source/adapters/hip/queue.cpp index cca434b8f6ab0..0d5be1289e0ce 100644 --- a/unified-runtime/source/adapters/hip/queue.cpp +++ b/unified-runtime/source/adapters/hip/queue.cpp @@ -107,7 +107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->RefCount.getCount()); case UR_QUEUE_INFO_FLAGS: return ReturnValue(hQueue->URFlags); case UR_QUEUE_INFO_EMPTY: { @@ -135,14 +135,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - UR_ASSERT(hQueue->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_QUEUE); + UR_ASSERT(hQueue->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_QUEUE); - hQueue->incrementReferenceCount(); + hQueue->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (hQueue->decrementReferenceCount() > 0) { + if (!hQueue->RefCount.release()) { return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/sampler.cpp b/unified-runtime/source/adapters/hip/sampler.cpp index addcdb031402e..dfbecf7a1a7cb 100644 --- a/unified-runtime/source/adapters/hip/sampler.cpp +++ b/unified-runtime/source/adapters/hip/sampler.cpp @@ -44,7 +44,7 @@ ur_result_t urSamplerGetInfo(ur_sampler_handle_t hSampler, switch (propName) { case UR_SAMPLER_INFO_REFERENCE_COUNT: - return ReturnValue(hSampler->getReferenceCount()); + return ReturnValue(hSampler->RefCount.getCount()); case UR_SAMPLER_INFO_CONTEXT: return ReturnValue(hSampler->Context); case UR_SAMPLER_INFO_NORMALIZED_COORDS: { @@ -67,19 +67,19 @@ ur_result_t urSamplerGetInfo(ur_sampler_handle_t hSampler, } ur_result_t urSamplerRetain(ur_sampler_handle_t hSampler) { - hSampler->incrementReferenceCount(); + hSampler->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urSamplerRelease(ur_sampler_handle_t hSampler) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hSampler->getReferenceCount() == 0) { + if (hSampler->RefCount.getCount() == 0) { return UR_RESULT_ERROR_INVALID_SAMPLER; } // decrement ref count. If it is 0, delete the sampler. - if (hSampler->decrementReferenceCount() == 0) { + if (hSampler->RefCount.release()) { delete hSampler; } diff --git a/unified-runtime/source/adapters/hip/sampler.hpp b/unified-runtime/source/adapters/hip/sampler.hpp index 1a1defea851ed..47b24617595df 100644 --- a/unified-runtime/source/adapters/hip/sampler.hpp +++ b/unified-runtime/source/adapters/hip/sampler.hpp @@ -10,6 +10,7 @@ #include +#include "common/ur_ref_count.hpp" #include "context.hpp" /// Implementation of samplers for HIP @@ -26,7 +27,7 @@ /// | 1 | filter mode /// | 0 | normalize coords struct ur_sampler_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; uint32_t Props; float MinMipmapLevelClamp; float MaxMipmapLevelClamp; @@ -34,13 +35,7 @@ struct ur_sampler_handle_t_ : ur::hip::handle_base { ur_context_handle_t Context; ur_sampler_handle_t_(ur_context_handle_t Context) - : handle_base(), RefCount(1), Props(0), Context(Context) {} - - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + : handle_base(), Props(0), Context(Context) {} ur_bool_t isNormalizedCoords() const noexcept { return static_cast(Props & 0b1); diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 24c0872c9ace3..aa2b945c3fc7f 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -442,14 +442,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - Pool->incrementReferenceCount(); + Pool->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRelease( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - if (Pool->decrementReferenceCount() > 0) { + if (!Pool->RefCount.release()) { return UR_RESULT_SUCCESS; } Pool->Context->removePool(Pool); @@ -472,7 +472,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->getReferenceCount()); + return ReturnValue(hPool->RefCount.getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->Context); diff --git a/unified-runtime/source/adapters/hip/usm.hpp b/unified-runtime/source/adapters/hip/usm.hpp index 43f8a11736a34..b770eac8555d0 100644 --- a/unified-runtime/source/adapters/hip/usm.hpp +++ b/unified-runtime/source/adapters/hip/usm.hpp @@ -9,6 +9,7 @@ //===-----------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_count.hpp" #include #include @@ -16,7 +17,7 @@ usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); struct ur_usm_pool_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount = 1; + ur::RefCount RefCount; ur_context_handle_t Context = nullptr; @@ -30,12 +31,6 @@ struct ur_usm_pool_handle_t_ : ur::hip::handle_base { ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc); - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - bool hasUMFPool(umf_memory_pool_t *umf_pool); }; diff --git a/unified-runtime/source/common/cuda-hip/stream_queue.hpp b/unified-runtime/source/common/cuda-hip/stream_queue.hpp index 0ead67e1d8729..d27df2d94a3f6 100644 --- a/unified-runtime/source/common/cuda-hip/stream_queue.hpp +++ b/unified-runtime/source/common/cuda-hip/stream_queue.hpp @@ -15,6 +15,8 @@ #include #include +#include "common/ur_ref_count.hpp" + using ur_stream_guard = std::unique_lock; /// Generic implementation of an out-of-order UR queue based on in-order @@ -44,7 +46,8 @@ struct stream_queue_t { std::vector TransferAppliedBarrier; ur_context_handle_t_ *Context; ur_device_handle_t_ *Device; - std::atomic_uint32_t RefCount{1}; + std::atomic_uint32_t RefCountOld{1}; + ur::RefCount RefCount; std::atomic_uint32_t EventCount{0}; std::atomic_uint32_t ComputeStreamIndex{0}; std::atomic_uint32_t TransferStreamIndex{0}; @@ -344,11 +347,11 @@ struct stream_queue_t { ur_context_handle_t_ *getContext() const { return Context; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } + uint32_t incrementReferenceCount() noexcept { return ++RefCountOld; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } + uint32_t decrementReferenceCount() noexcept { return --RefCountOld; } - uint32_t getReferenceCount() const noexcept { return RefCount; } + uint32_t getReferenceCount() const noexcept { return RefCountOld; } uint32_t getNextEventId() noexcept { return ++EventCount; }