Skip to content

[UR][HIP] Refactor reference counting in UR HIP adapter using new ur::RefCount class #19387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/hip/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ur_legacy_sink : public logger::Sink {
// through UR entry points.
// https://github.com/oneapi-src/unified-runtime/issues/1330
ur_adapter_handle_t_::ur_adapter_handle_t_()
: handle_base(),
: handle_base(), RefCount(0),
logger(logger::get_logger("hip",
/*default_log_level*/ UR_LOGGER_LEVEL_ERROR)) {
Platform = std::make_unique<ur_platform_handle_t_>();
Expand All @@ -58,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
std::call_once(InitFlag,
[=]() { ur::hip::adapter = new ur_adapter_handle_t_; });

ur::hip::adapter->RefCount++;
ur::hip::adapter->RefCount.retain();
*phAdapters = ur::hip::adapter;
}
if (pNumAdapters) {
Expand All @@ -69,15 +69,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
if (--ur::hip::adapter->RefCount == 0) {
if (ur::hip::adapter->RefCount.release()) {
delete ur::hip::adapter;
}

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
ur::hip::adapter->RefCount++;
ur::hip::adapter->RefCount.retain();
return UR_RESULT_SUCCESS;
}

Expand All @@ -99,7 +99,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_BACKEND_HIP);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(ur::hip::adapter->RefCount.load());
return ReturnValue(ur::hip::adapter->RefCount.getCount());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
3 changes: 2 additions & 1 deletion unified-runtime/source/adapters/hip/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
#ifndef UR_HIP_ADAPTER_HPP_INCLUDED
#define UR_HIP_ADAPTER_HPP_INCLUDED

#include "common/ur_ref_count.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"

#include <atomic>
#include <memory>

struct ur_adapter_handle_t_ : ur::hip::handle_base {
std::atomic<uint32_t> RefCount = 0;
ur::RefCount RefCount;
logger::Logger &logger;
std::unique_ptr<ur_platform_handle_t_> Platform;
ur_adapter_handle_t_();
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
bool IsInOrder)
: handle_base(), Context(hContext), Device(hDevice),
IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), HIPGraph{nullptr},
HIPGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} {
HIPGraphExec{nullptr}, NextSyncPoint{0} {
urContextRetain(hContext);
}

Expand Down Expand Up @@ -266,13 +266,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
hCommandBuffer->incrementReferenceCount();
hCommandBuffer->RefCount.retain();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
if (hCommandBuffer->decrementReferenceCount() == 0) {
if (hCommandBuffer->RefCount.release()) {
if (hCommandBuffer->CurrentExecution) {
UR_CHECK_ERROR(hCommandBuffer->CurrentExecution->wait());
UR_CHECK_ERROR(urEventRelease(hCommandBuffer->CurrentExecution));
Expand Down Expand Up @@ -1055,7 +1055,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(

switch (propName) {
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
return ReturnValue(hCommandBuffer->getReferenceCount());
return ReturnValue(hCommandBuffer->RefCount.getCount());
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
ur_exp_command_buffer_desc_t Descriptor{};
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;
Expand Down
9 changes: 3 additions & 6 deletions unified-runtime/source/adapters/hip/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//
//===----------------------------------------------------------------------===//

#include "common/ur_ref_count.hpp"
#include <ur/ur.hpp>
#include <ur_api.h>
#include <ur_print.hpp>
Expand Down Expand Up @@ -109,9 +110,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
registerSyncPoint(SyncPoint, std::move(HIPNode));
return SyncPoint;
}
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
uint32_t getReferenceCount() const noexcept { return RefCount; }

// UR context associated with this command-buffer
ur_context_handle_t Context;
Expand All @@ -125,9 +123,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
hipGraph_t HIPGraph;
// HIP Graph Exec handle
hipGraphExec_t HIPGraphExec = nullptr;
// Atomic variable counting the number of reference to this command_buffer
// using std::atomic prevents data race when incrementing/decrementing.
std::atomic_uint32_t RefCount;
// Track the event of the current graph execution. This extra synchronization
// is needed because HIP (unlike CUDA) does not seem to synchronize with other
// executions of the same graph during hipGraphLaunch and hipExecGraphDestroy.
Expand All @@ -142,4 +137,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
// Handles to individual commands in the command-buffer
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
CommandHandles;

ur::RefCount RefCount;
};
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/hip/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
return ReturnValue(hContext->getDevices().data(),
hContext->getDevices().size());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
return ReturnValue(hContext->RefCount.getCount());
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
// 2D USM memcpy is supported.
return ReturnValue(true);
Expand All @@ -85,7 +85,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
if (hContext->decrementReferenceCount() == 0) {
if (hContext->RefCount.release()) {
hContext->invokeExtendedDeleters();
delete hContext;
}
Expand All @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {
assert(hContext->getReferenceCount() > 0);
assert(hContext->RefCount.getCount() > 0);

hContext->incrementReferenceCount();
hContext->RefCount.retain();
return UR_RESULT_SUCCESS;
}

Expand Down
11 changes: 3 additions & 8 deletions unified-runtime/source/adapters/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "adapter.hpp"
#include "common.hpp"
#include "common/ur_ref_count.hpp"
#include "device.hpp"
#include "platform.hpp"

Expand Down Expand Up @@ -88,10 +89,10 @@ struct ur_context_handle_t_ : ur::hip::handle_base {

std::vector<ur_device_handle_t> Devices;

std::atomic_uint32_t RefCount;
ur::RefCount RefCount;

ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} {
: handle_base(), Devices{Devs, Devs + NumDevices} {
UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter));
};

Expand Down Expand Up @@ -125,12 +126,6 @@ struct ur_context_handle_t_ : ur::hip::handle_base {
return std::distance(Devices.begin(), It);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

void addPool(ur_usm_pool_handle_t Pool);

void removePool(ur_usm_pool_handle_t Pool);
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
return ReturnValue("HIP");
}
case UR_DEVICE_INFO_REFERENCE_COUNT: {
return ReturnValue(hDevice->getReferenceCount());
return ReturnValue(hDevice->RefCount.getCount());
}
case UR_DEVICE_INFO_VERSION: {
std::stringstream S;
Expand Down
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#pragma once

#include "common.hpp"
#include "common/ur_ref_count.hpp"

#include <ur/ur.hpp>

Expand All @@ -22,7 +23,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
using native_type = hipDevice_t;

native_type HIPDevice;
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
hipEvent_t EvBase; // HIP event used as base counter
uint32_t DeviceIndex;
Expand All @@ -38,8 +38,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
public:
ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase,
ur_platform_handle_t Platform, uint32_t DeviceIndex)
: handle_base(), HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
EvBase(EvBase), DeviceIndex(DeviceIndex) {
: handle_base(), HIPDevice(HipDevice), Platform(Platform), EvBase(EvBase),
DeviceIndex(DeviceIndex) {

UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
Expand Down Expand Up @@ -99,8 +99,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base {

native_type get() const noexcept { return HIPDevice; };

uint32_t getReferenceCount() const noexcept { return RefCount; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

uint64_t getElapsedTime(hipEvent_t) const;
Expand All @@ -126,6 +124,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
};

bool supportsHardwareImages() const noexcept { return HardwareImageSupport; }

ur::RefCount RefCount;
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/hip/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
case UR_EVENT_INFO_COMMAND_TYPE:
return ReturnValue(hEvent->getCommandType());
case UR_EVENT_INFO_REFERENCE_COUNT:
return ReturnValue(hEvent->getReferenceCount());
return ReturnValue(hEvent->RefCount.getCount());
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
try {
return ReturnValue(hEvent->getExecutionStatus());
Expand Down Expand Up @@ -245,7 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventSetCallback(ur_event_handle_t,
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
const auto RefCount = hEvent->incrementReferenceCount();
const auto RefCount = hEvent->RefCount.retain();

if (RefCount == 0) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
Expand All @@ -257,12 +257,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
if (hEvent->getReferenceCount() == 0) {
if (hEvent->RefCount.getCount() == 0) {
return UR_RESULT_ERROR_INVALID_EVENT;
}

// decrement ref count. If it is 0, delete the event.
if (hEvent->decrementReferenceCount() == 0) {
if (hEvent->RefCount.release()) {
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
try {
Expand Down
8 changes: 2 additions & 6 deletions unified-runtime/source/adapters/hip/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#pragma once

#include "common.hpp"
#include "common/ur_ref_count.hpp"
#include "queue.hpp"

/// UR Event mapping to hipEvent_t
Expand Down Expand Up @@ -57,16 +58,11 @@ struct ur_event_handle_t_ : ur::hip::handle_base {
ur_context_handle_t getContext() const noexcept { return Context; };
uint32_t getEventId() const noexcept { return EventId; }

// Reference counting.
uint32_t getReferenceCount() const noexcept { return RefCount; }
uint32_t incrementReferenceCount() { return ++RefCount; }
uint32_t decrementReferenceCount() { return --RefCount; }
ur::RefCount RefCount;

private:
ur_command_t CommandType; // The type of command associated with event.

std::atomic_uint32_t RefCount{1}; // Event reference count.

bool HasOwnership{true}; // Signifies if event owns the native type.
bool HasProfiling{false}; // Signifies if event has profiling information.

Expand Down
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/hip/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,20 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->RefCount.getCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);

hKernel->incrementReferenceCount();
hKernel->RefCount.retain();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelRelease(ur_kernel_handle_t hKernel) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);

// decrement ref count. If it is 0, delete the program.
if (hKernel->decrementReferenceCount() == 0) {
if (hKernel->RefCount.release()) {
// no internal cuda resources to clean up. Just delete it.
delete hKernel;
return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -201,7 +201,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
case UR_KERNEL_INFO_NUM_ARGS:
return ReturnValue(hKernel->getNumArgs());
case UR_KERNEL_INFO_REFERENCE_COUNT:
return ReturnValue(hKernel->getReferenceCount());
return ReturnValue(hKernel->RefCount.getCount());
case UR_KERNEL_INFO_CONTEXT:
return ReturnValue(hKernel->getContext());
case UR_KERNEL_INFO_PROGRAM:
Expand Down
11 changes: 3 additions & 8 deletions unified-runtime/source/adapters/hip/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//
#pragma once

#include "common/ur_ref_count.hpp"
#include <ur_api.h>

#include <array>
Expand Down Expand Up @@ -41,7 +42,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {
std::string Name;
ur_context_handle_t Context;
ur_program_handle_t Program;
std::atomic_uint32_t RefCount;
ur::RefCount RefCount;

static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
Expand Down Expand Up @@ -238,7 +239,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {
ur_context_handle_t Ctxt)
: handle_base(), Function{Func},
FunctionWithOffsetParam{FuncWithOffsetParam}, Name{Name}, Context{Ctxt},
Program{Program}, RefCount{1} {
Program{Program} {
urProgramRetain(Program);
urContextRetain(Context);

Expand Down Expand Up @@ -267,12 +268,6 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {

ur_program_handle_t getProgram() const noexcept { return Program; }

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

native_type get() const noexcept { return Function; };

native_type getWithOffsetParameter() const noexcept {
Expand Down
Loading
Loading