Skip to content

Commit bdf41ce

Browse files
committed
Refactor reference counting in UR HIP adapter using new ur::RefCount class.
1 parent ebe9758 commit bdf41ce

File tree

15 files changed

+50
-94
lines changed

15 files changed

+50
-94
lines changed

unified-runtime/source/adapters/hip/adapter.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ur_legacy_sink : public logger::Sink {
3838
// through UR entry points.
3939
// https://github.com/oneapi-src/unified-runtime/issues/1330
4040
ur_adapter_handle_t_::ur_adapter_handle_t_()
41-
: handle_base(),
41+
: handle_base(), RefCount(0),
4242
logger(logger::get_logger("hip",
4343
/*default_log_level*/ UR_LOGGER_LEVEL_ERROR)) {
4444
Platform = std::make_unique<ur_platform_handle_t_>();
@@ -58,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
5858
std::call_once(InitFlag,
5959
[=]() { ur::hip::adapter = new ur_adapter_handle_t_; });
6060

61-
ur::hip::adapter->RefCount++;
61+
ur::hip::adapter->RefCount.retain();
6262
*phAdapters = ur::hip::adapter;
6363
}
6464
if (pNumAdapters) {
@@ -69,15 +69,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
6969
}
7070

7171
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
72-
if (--ur::hip::adapter->RefCount == 0) {
72+
if (ur::hip::adapter->RefCount.release()) {
7373
delete ur::hip::adapter;
7474
}
7575

7676
return UR_RESULT_SUCCESS;
7777
}
7878

7979
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
80-
ur::hip::adapter->RefCount++;
80+
ur::hip::adapter->RefCount.retain();
8181
return UR_RESULT_SUCCESS;
8282
}
8383

@@ -99,7 +99,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
9999
case UR_ADAPTER_INFO_BACKEND:
100100
return ReturnValue(UR_BACKEND_HIP);
101101
case UR_ADAPTER_INFO_REFERENCE_COUNT:
102-
return ReturnValue(ur::hip::adapter->RefCount.load());
102+
return ReturnValue(ur::hip::adapter->RefCount.getCount());
103103
case UR_ADAPTER_INFO_VERSION:
104104
return ReturnValue(uint32_t{1});
105105
default:

unified-runtime/source/adapters/hip/adapter.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
#ifndef UR_HIP_ADAPTER_HPP_INCLUDED
1212
#define UR_HIP_ADAPTER_HPP_INCLUDED
1313

14+
#include "common/ur_ref_count.hpp"
1415
#include "logger/ur_logger.hpp"
1516
#include "platform.hpp"
1617

1718
#include <atomic>
1819
#include <memory>
1920

2021
struct ur_adapter_handle_t_ : ur::hip::handle_base {
21-
std::atomic<uint32_t> RefCount = 0;
22+
ur::RefCount RefCount;
2223
logger::Logger &logger;
2324
std::unique_ptr<ur_platform_handle_t_> Platform;
2425
ur_adapter_handle_t_();

unified-runtime/source/adapters/hip/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
2626
bool IsInOrder)
2727
: handle_base(), Context(hContext), Device(hDevice),
2828
IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), HIPGraph{nullptr},
29-
HIPGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} {
29+
HIPGraphExec{nullptr}, NextSyncPoint{0} {
3030
urContextRetain(hContext);
3131
}
3232

unified-runtime/source/adapters/hip/command_buffer.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "common/ur_ref_count.hpp"
1112
#include <ur/ur.hpp>
1213
#include <ur_api.h>
1314
#include <ur_print.hpp>
@@ -109,9 +110,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
109110
registerSyncPoint(SyncPoint, std::move(HIPNode));
110111
return SyncPoint;
111112
}
112-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
113-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
114-
uint32_t getReferenceCount() const noexcept { return RefCount; }
115113

116114
// UR context associated with this command-buffer
117115
ur_context_handle_t Context;
@@ -125,9 +123,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
125123
hipGraph_t HIPGraph;
126124
// HIP Graph Exec handle
127125
hipGraphExec_t HIPGraphExec = nullptr;
128-
// Atomic variable counting the number of reference to this command_buffer
129-
// using std::atomic prevents data race when incrementing/decrementing.
130-
std::atomic_uint32_t RefCount;
131126
// Track the event of the current graph execution. This extra synchronization
132127
// is needed because HIP (unlike CUDA) does not seem to synchronize with other
133128
// executions of the same graph during hipGraphLaunch and hipExecGraphDestroy.
@@ -142,4 +137,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base {
142137
// Handles to individual commands in the command-buffer
143138
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
144139
CommandHandles;
140+
141+
ur::RefCount RefCount;
145142
};

unified-runtime/source/adapters/hip/context.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "adapter.hpp"
1515
#include "common.hpp"
16+
#include "common/ur_ref_count.hpp"
1617
#include "device.hpp"
1718
#include "platform.hpp"
1819

@@ -88,10 +89,10 @@ struct ur_context_handle_t_ : ur::hip::handle_base {
8889

8990
std::vector<ur_device_handle_t> Devices;
9091

91-
std::atomic_uint32_t RefCount;
92+
ur::RefCount RefCount;
9293

9394
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
94-
: handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} {
95+
: handle_base(), Devices{Devs, Devs + NumDevices} {
9596
UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter));
9697
};
9798

@@ -125,12 +126,6 @@ struct ur_context_handle_t_ : ur::hip::handle_base {
125126
return std::distance(Devices.begin(), It);
126127
}
127128

128-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
129-
130-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
131-
132-
uint32_t getReferenceCount() const noexcept { return RefCount; }
133-
134129
void addPool(ur_usm_pool_handle_t Pool);
135130

136131
void removePool(ur_usm_pool_handle_t Pool);

unified-runtime/source/adapters/hip/device.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314

1415
#include <ur/ur.hpp>
1516

@@ -22,7 +23,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
2223
using native_type = hipDevice_t;
2324

2425
native_type HIPDevice;
25-
std::atomic_uint32_t RefCount;
26+
ur::RefCount RefCount;
2627
ur_platform_handle_t Platform;
2728
hipEvent_t EvBase; // HIP event used as base counter
2829
uint32_t DeviceIndex;
@@ -38,8 +39,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
3839
public:
3940
ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase,
4041
ur_platform_handle_t Platform, uint32_t DeviceIndex)
41-
: handle_base(), HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
42-
EvBase(EvBase), DeviceIndex(DeviceIndex) {
42+
: handle_base(), HIPDevice(HipDevice), Platform(Platform), EvBase(EvBase),
43+
DeviceIndex(DeviceIndex) {
4344

4445
UR_CHECK_ERROR(hipDeviceGetAttribute(
4546
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
@@ -99,8 +100,6 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
99100

100101
native_type get() const noexcept { return HIPDevice; };
101102

102-
uint32_t getReferenceCount() const noexcept { return RefCount; }
103-
104103
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
105104

106105
uint64_t getElapsedTime(hipEvent_t) const;

unified-runtime/source/adapters/hip/event.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
189189
case UR_EVENT_INFO_COMMAND_TYPE:
190190
return ReturnValue(hEvent->getCommandType());
191191
case UR_EVENT_INFO_REFERENCE_COUNT:
192-
return ReturnValue(hEvent->getReferenceCount());
192+
return ReturnValue(hEvent->RefCount.getCount());
193193
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
194194
try {
195195
return ReturnValue(hEvent->getExecutionStatus());
@@ -245,7 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventSetCallback(ur_event_handle_t,
245245
}
246246

247247
UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
248-
const auto RefCount = hEvent->incrementReferenceCount();
248+
const auto RefCount = hEvent->RefCount.retain();
249249

250250
if (RefCount == 0) {
251251
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -257,12 +257,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
257257
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
258258
// double delete or someone is messing with the ref count.
259259
// either way, cannot safely proceed.
260-
if (hEvent->getReferenceCount() == 0) {
260+
if (hEvent->RefCount.getCount() == 0) {
261261
return UR_RESULT_ERROR_INVALID_EVENT;
262262
}
263263

264264
// decrement ref count. If it is 0, delete the event.
265-
if (hEvent->decrementReferenceCount() == 0) {
265+
if (hEvent->RefCount.release()) {
266266
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
267267
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
268268
try {

unified-runtime/source/adapters/hip/event.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314
#include "queue.hpp"
1415

1516
/// UR Event mapping to hipEvent_t
@@ -57,16 +58,11 @@ struct ur_event_handle_t_ : ur::hip::handle_base {
5758
ur_context_handle_t getContext() const noexcept { return Context; };
5859
uint32_t getEventId() const noexcept { return EventId; }
5960

60-
// Reference counting.
61-
uint32_t getReferenceCount() const noexcept { return RefCount; }
62-
uint32_t incrementReferenceCount() { return ++RefCount; }
63-
uint32_t decrementReferenceCount() { return --RefCount; }
61+
ur::RefCount RefCount;
6462

6563
private:
6664
ur_command_t CommandType; // The type of command associated with event.
6765

68-
std::atomic_uint32_t RefCount{1}; // Event reference count.
69-
7066
bool HasOwnership{true}; // Signifies if event owns the native type.
7167
bool HasProfiling{false}; // Signifies if event has profiling information.
7268

unified-runtime/source/adapters/hip/kernel.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "common/ur_ref_count.hpp"
1213
#include <ur_api.h>
1314

1415
#include <array>
@@ -41,7 +42,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {
4142
std::string Name;
4243
ur_context_handle_t Context;
4344
ur_program_handle_t Program;
44-
std::atomic_uint32_t RefCount;
45+
ur::RefCount RefCount;
4546

4647
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
4748
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
@@ -238,7 +239,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {
238239
ur_context_handle_t Ctxt)
239240
: handle_base(), Function{Func},
240241
FunctionWithOffsetParam{FuncWithOffsetParam}, Name{Name}, Context{Ctxt},
241-
Program{Program}, RefCount{1} {
242+
Program{Program} {
242243
urProgramRetain(Program);
243244
urContextRetain(Context);
244245

@@ -267,12 +268,6 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base {
267268

268269
ur_program_handle_t getProgram() const noexcept { return Program; }
269270

270-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
271-
272-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
273-
274-
uint32_t getReferenceCount() const noexcept { return RefCount; }
275-
276271
native_type get() const noexcept { return Function; };
277272

278273
native_type getWithOffsetParameter() const noexcept {

unified-runtime/source/adapters/hip/memory.hpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314
#include "context.hpp"
1415
#include "event.hpp"
16+
1517
#include <cassert>
1618
#include <memory>
1719
#include <unordered_map>
@@ -317,7 +319,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base {
317319
ur_context Context;
318320

319321
/// Reference counting of the handler
320-
std::atomic_uint32_t RefCount;
322+
ur::RefCount RefCount;
321323

322324
// Original mem flags passed
323325
ur_mem_flags_t MemFlags;
@@ -347,17 +349,17 @@ struct ur_mem_handle_t_ : ur::hip::handle_base {
347349
/// Constructs the UR mem handler for a non-typed allocation ("buffer")
348350
ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags,
349351
BufferMem::AllocMode Mode, void *HostPtr, size_t Size)
350-
: Context{Ctxt}, RefCount{1}, MemFlags{MemFlags},
352+
: Context{Ctxt}, MemFlags{MemFlags},
351353
HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false),
352354
Mem{std::in_place_type<BufferMem>, Ctxt, this, Mode, HostPtr, Size} {
353355
urContextRetain(Context);
354356
}
355357

356358
// Subbuffer constructor
357359
ur_mem_handle_t_(ur_mem Parent, size_t SubBufferOffset)
358-
: handle_base(), Context{Parent->Context}, RefCount{1},
359-
MemFlags{Parent->MemFlags}, HaveMigratedToDeviceSinceLastWrite(
360-
Parent->Context->Devices.size(), false),
360+
: handle_base(), Context{Parent->Context}, MemFlags{Parent->MemFlags},
361+
HaveMigratedToDeviceSinceLastWrite(Parent->Context->Devices.size(),
362+
false),
361363
Mem{BufferMem{std::get<BufferMem>(Parent->Mem)}} {
362364
auto &SubBuffer = std::get<BufferMem>(Mem);
363365
SubBuffer.Parent = Parent;
@@ -378,7 +380,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base {
378380
ur_mem_handle_t_(ur_context Ctxt, ur_mem_flags_t MemFlags,
379381
ur_image_format_t ImageFormat, ur_image_desc_t ImageDesc,
380382
void *HostPtr)
381-
: Context{Ctxt}, RefCount{1}, MemFlags{MemFlags},
383+
: Context{Ctxt}, MemFlags{MemFlags},
382384
HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false),
383385
Mem{std::in_place_type<SurfaceMem>,
384386
Ctxt,
@@ -419,12 +421,6 @@ struct ur_mem_handle_t_ : ur::hip::handle_base {
419421

420422
ur_context getContext() const noexcept { return Context; }
421423

422-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
423-
424-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
425-
426-
uint32_t getReferenceCount() const noexcept { return RefCount; }
427-
428424
void setLastQueueWritingToMemObj(ur_queue_handle_t WritingQueue) {
429425
if (LastQueueWritingToMemObj != nullptr) {
430426
urQueueRelease(LastQueueWritingToMemObj);

0 commit comments

Comments
 (0)