Skip to content

Commit d2740e6

Browse files
committed
Refactor reference counting in UR HIP adapter using new ur::RefCount class.
1 parent ebe9758 commit d2740e6

23 files changed

+87
-127
lines changed

unified-runtime/source/adapters/hip/adapter.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ur_legacy_sink : public logger::Sink {
3838
// through UR entry points.
3939
// https://github.com/oneapi-src/unified-runtime/issues/1330
4040
ur_adapter_handle_t_::ur_adapter_handle_t_()
41-
: handle_base(),
41+
: handle_base(), RefCount(0),
4242
logger(logger::get_logger("hip",
4343
/*default_log_level*/ UR_LOGGER_LEVEL_ERROR)) {
4444
Platform = std::make_unique<ur_platform_handle_t_>();
@@ -58,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
5858
std::call_once(InitFlag,
5959
[=]() { ur::hip::adapter = new ur_adapter_handle_t_; });
6060

61-
ur::hip::adapter->RefCount++;
61+
ur::hip::adapter->RefCount.retain();
6262
*phAdapters = ur::hip::adapter;
6363
}
6464
if (pNumAdapters) {
@@ -69,15 +69,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
6969
}
7070

7171
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
72-
if (--ur::hip::adapter->RefCount == 0) {
72+
if (ur::hip::adapter->RefCount.release()) {
7373
delete ur::hip::adapter;
7474
}
7575

7676
return UR_RESULT_SUCCESS;
7777
}
7878

7979
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
80-
ur::hip::adapter->RefCount++;
80+
ur::hip::adapter->RefCount.retain();
8181
return UR_RESULT_SUCCESS;
8282
}
8383

@@ -99,7 +99,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
9999
case UR_ADAPTER_INFO_BACKEND:
100100
return ReturnValue(UR_BACKEND_HIP);
101101
case UR_ADAPTER_INFO_REFERENCE_COUNT:
102-
return ReturnValue(ur::hip::adapter->RefCount.load());
102+
return ReturnValue(ur::hip::adapter->RefCount.getCount());
103103
case UR_ADAPTER_INFO_VERSION:
104104
return ReturnValue(uint32_t{1});
105105
default:

unified-runtime/source/adapters/hip/adapter.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
#ifndef UR_HIP_ADAPTER_HPP_INCLUDED
1212
#define UR_HIP_ADAPTER_HPP_INCLUDED
1313

14+
#include "common/ur_ref_count.hpp"
1415
#include "logger/ur_logger.hpp"
1516
#include "platform.hpp"
1617

1718
#include <atomic>
1819
#include <memory>
1920

2021
struct ur_adapter_handle_t_ : ur::hip::handle_base {
21-
std::atomic<uint32_t> RefCount = 0;
22+
ur::RefCount RefCount;
2223
logger::Logger &logger;
2324
std::unique_ptr<ur_platform_handle_t_> Platform;
2425
ur_adapter_handle_t_();

unified-runtime/source/adapters/hip/command_buffer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
2626
bool IsInOrder)
2727
: handle_base(), Context(hContext), Device(hDevice),
2828
IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), HIPGraph{nullptr},
29-
HIPGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} {
29+
HIPGraphExec{nullptr}, NextSyncPoint{0} {
3030
urContextRetain(hContext);
3131
}
3232

@@ -266,13 +266,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
266266

267267
UR_APIEXPORT ur_result_t UR_APICALL
268268
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
269-
hCommandBuffer->incrementReferenceCount();
269+
hCommandBuffer->RefCount.retain();
270270
return UR_RESULT_SUCCESS;
271271
}
272272

273273
UR_APIEXPORT ur_result_t UR_APICALL
274274
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
275-
if (hCommandBuffer->decrementReferenceCount() == 0) {
275+
if (hCommandBuffer->RefCount.release()) {
276276
if (hCommandBuffer->CurrentExecution) {
277277
UR_CHECK_ERROR(hCommandBuffer->CurrentExecution->wait());
278278
UR_CHECK_ERROR(urEventRelease(hCommandBuffer->CurrentExecution));
@@ -1055,7 +1055,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(
10551055

10561056
switch (propName) {
10571057
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
1058-
return ReturnValue(hCommandBuffer->getReferenceCount());
1058+
return ReturnValue(hCommandBuffer->RefCount.getCount());
10591059
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
10601060
ur_exp_command_buffer_desc_t Descriptor{};
10611061
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;

unified-runtime/source/adapters/hip/command_buffer.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "common/ur_ref_count.hpp"
1112
#include <ur/ur.hpp>
1213
#include <ur_api.h>
1314
#include <ur_print.hpp>
@@ -109,9 +110,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
109110
registerSyncPoint(SyncPoint, std::move(HIPNode));
110111
return SyncPoint;
111112
}
112-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
113-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
114-
uint32_t getReferenceCount() const noexcept { return RefCount; }
115113

116114
// UR context associated with this command-buffer
117115
ur_context_handle_t Context;
@@ -125,9 +123,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
125123
hipGraph_t HIPGraph;
126124
// HIP Graph Exec handle
127125
hipGraphExec_t HIPGraphExec = nullptr;
128-
// Atomic variable counting the number of reference to this command_buffer
129-
// using std::atomic prevents data race when incrementing/decrementing.
130-
std::atomic_uint32_t RefCount;
131126
// Track the event of the current graph execution. This extra synchronization
132127
// is needed because HIP (unlike CUDA) does not seem to synchronize with other
133128
// executions of the same graph during hipGraphLaunch and hipExecGraphDestroy.
@@ -142,4 +137,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
142137
// Handles to individual commands in the command-buffer
143138
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
144139
CommandHandles;
140+
141+
ur::RefCount RefCount;
145142
};

unified-runtime/source/adapters/hip/context.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
6868
return ReturnValue(hContext->getDevices().data(),
6969
hContext->getDevices().size());
7070
case UR_CONTEXT_INFO_REFERENCE_COUNT:
71-
return ReturnValue(hContext->getReferenceCount());
71+
return ReturnValue(hContext->RefCount.getCount());
7272
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
7373
// 2D USM memcpy is supported.
7474
return ReturnValue(true);
@@ -85,7 +85,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
8585

8686
UR_APIEXPORT ur_result_t UR_APICALL
8787
urContextRelease(ur_context_handle_t hContext) {
88-
if (hContext->decrementReferenceCount() == 0) {
88+
if (hContext->RefCount.release()) {
8989
hContext->invokeExtendedDeleters();
9090
delete hContext;
9191
}
@@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) {
9494

9595
UR_APIEXPORT ur_result_t UR_APICALL
9696
urContextRetain(ur_context_handle_t hContext) {
97-
assert(hContext->getReferenceCount() > 0);
97+
assert(hContext->RefCount.getCount() > 0);
9898

99-
hContext->incrementReferenceCount();
99+
hContext->RefCount.retain();
100100
return UR_RESULT_SUCCESS;
101101
}
102102

unified-runtime/source/adapters/hip/context.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "adapter.hpp"
1515
#include "common.hpp"
16+
#include "common/ur_ref_count.hpp"
1617
#include "device.hpp"
1718
#include "platform.hpp"
1819

@@ -88,10 +89,10 @@ struct ur_context_handle_t_ : ur::hip::handle_base {
8889

8990
std::vector<ur_device_handle_t> Devices;
9091

91-
std::atomic_uint32_t RefCount;
92+
ur::RefCount RefCount;
9293

9394
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
94-
: handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} {
95+
: handle_base(), Devices{Devs, Devs + NumDevices} {
9596
UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter));
9697
};
9798

@@ -125,12 +126,6 @@ struct ur_context_handle_t_ : ur::hip::handle_base {
125126
return std::distance(Devices.begin(), It);
126127
}
127128

128-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
129-
130-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
131-
132-
uint32_t getReferenceCount() const noexcept { return RefCount; }
133-
134129
void addPool(ur_usm_pool_handle_t Pool);
135130

136131
void removePool(ur_usm_pool_handle_t Pool);

unified-runtime/source/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
478478
return ReturnValue("HIP");
479479
}
480480
case UR_DEVICE_INFO_REFERENCE_COUNT: {
481-
return ReturnValue(hDevice->getReferenceCount());
481+
return ReturnValue(hDevice->RefCount.getCount());
482482
}
483483
case UR_DEVICE_INFO_VERSION: {
484484
std::stringstream S;

unified-runtime/source/adapters/hip/device.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314

1415
#include <ur/ur.hpp>
1516

@@ -22,7 +23,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
2223
using native_type = hipDevice_t;
2324

2425
native_type HIPDevice;
25-
std::atomic_uint32_t RefCount;
26+
ur::RefCount RefCount;
2627
ur_platform_handle_t Platform;
2728
hipEvent_t EvBase; // HIP event used as base counter
2829
uint32_t DeviceIndex;
@@ -38,8 +39,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
3839
public:
3940
ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase,
4041
ur_platform_handle_t Platform, uint32_t DeviceIndex)
41-
: handle_base(), HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
42-
EvBase(EvBase), DeviceIndex(DeviceIndex) {
42+
: handle_base(), HIPDevice(HipDevice), Platform(Platform), EvBase(EvBase),
43+
DeviceIndex(DeviceIndex) {
4344

4445
UR_CHECK_ERROR(hipDeviceGetAttribute(
4546
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
@@ -99,8 +100,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
99100

100101
native_type get() const noexcept { return HIPDevice; };
101102

102-
uint32_t getReferenceCount() const noexcept { return RefCount; }
103-
104103
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
105104

106105
uint64_t getElapsedTime(hipEvent_t) const;

unified-runtime/source/adapters/hip/event.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
189189
case UR_EVENT_INFO_COMMAND_TYPE:
190190
return ReturnValue(hEvent->getCommandType());
191191
case UR_EVENT_INFO_REFERENCE_COUNT:
192-
return ReturnValue(hEvent->getReferenceCount());
192+
return ReturnValue(hEvent->RefCount.getCount());
193193
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
194194
try {
195195
return ReturnValue(hEvent->getExecutionStatus());
@@ -245,7 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventSetCallback(ur_event_handle_t,
245245
}
246246

247247
UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
248-
const auto RefCount = hEvent->incrementReferenceCount();
248+
const auto RefCount = hEvent->RefCount.retain();
249249

250250
if (RefCount == 0) {
251251
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -257,12 +257,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
257257
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
258258
// double delete or someone is messing with the ref count.
259259
// either way, cannot safely proceed.
260-
if (hEvent->getReferenceCount() == 0) {
260+
if (hEvent->RefCount.getCount() == 0) {
261261
return UR_RESULT_ERROR_INVALID_EVENT;
262262
}
263263

264264
// decrement ref count. If it is 0, delete the event.
265-
if (hEvent->decrementReferenceCount() == 0) {
265+
if (hEvent->RefCount.release()) {
266266
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
267267
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
268268
try {

unified-runtime/source/adapters/hip/event.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314
#include "queue.hpp"
1415

1516
/// UR Event mapping to hipEvent_t
@@ -57,16 +58,11 @@ struct ur_event_handle_t_ : ur::hip::handle_base {
5758
ur_context_handle_t getContext() const noexcept { return Context; };
5859
uint32_t getEventId() const noexcept { return EventId; }
5960

60-
// Reference counting.
61-
uint32_t getReferenceCount() const noexcept { return RefCount; }
62-
uint32_t incrementReferenceCount() { return ++RefCount; }
63-
uint32_t decrementReferenceCount() { return --RefCount; }
61+
ur::RefCount RefCount;
6462

6563
private:
6664
ur_command_t CommandType; // The type of command associated with event.
6765

68-
std::atomic_uint32_t RefCount{1}; // Event reference count.
69-
7066
bool HasOwnership{true}; // Signifies if event owns the native type.
7167
bool HasProfiling{false}; // Signifies if event has profiling information.
7268

0 commit comments

Comments
 (0)