Skip to content

Commit c5608a4

Browse files
committed
Initial support for multi-device contexts
1 parent 0260118 commit c5608a4

File tree

12 files changed

+250
-53
lines changed

12 files changed

+250
-53
lines changed

unified-runtime/source/adapters/offload/context.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
1515
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
1616
const ur_context_properties_t *, ur_context_handle_t *phContext) {
17-
if (DeviceCount > 1) {
18-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
17+
18+
// For multi-device contexts, all devices must have the same platform.
19+
ur_device_handle_t FirstDevice = *phDevices;
20+
for (uint32_t i = 1; i < DeviceCount; i++) {
21+
if (phDevices[i]->Platform != FirstDevice->Platform) {
22+
return UR_RESULT_ERROR_INVALID_DEVICE;
23+
}
1924
}
2025

21-
auto Ctx = new ur_context_handle_t_(*phDevices);
26+
auto Ctx = new ur_context_handle_t_(phDevices, DeviceCount);
2227
*phContext = Ctx;
2328
return UR_RESULT_SUCCESS;
2429
}
@@ -30,9 +35,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
3035

3136
switch (propName) {
3237
case UR_CONTEXT_INFO_NUM_DEVICES:
33-
return ReturnValue(uint32_t{1});
38+
return ReturnValue(hContext->Devices.size());
3439
case UR_CONTEXT_INFO_DEVICES:
35-
return ReturnValue(&hContext->Device, 1);
40+
return ReturnValue(hContext->Devices.data(), hContext->Devices.size());
3641
case UR_CONTEXT_INFO_REFERENCE_COUNT:
3742
return ReturnValue(hContext->RefCount.load());
3843
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:

unified-runtime/source/adapters/offload/context.hpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,28 @@
1818
#include <ur_api.h>
1919

2020
struct ur_context_handle_t_ : RefCounted {
21-
ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
22-
urDeviceRetain(Device);
21+
ur_context_handle_t_(const ur_device_handle_t *Devs, size_t NumDevices)
22+
: Devices{Devs, Devs + NumDevices} {
23+
for (auto Device : Devices) {
24+
urDeviceRetain(Device);
25+
}
2326
}
24-
~ur_context_handle_t_() { urDeviceRelease(Device); }
27+
~ur_context_handle_t_() {
28+
for (auto Device : Devices) {
29+
urDeviceRelease(Device);
30+
}
31+
}
32+
33+
std::vector<ur_device_handle_t> Devices;
2534

26-
ur_device_handle_t Device;
27-
std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
35+
// Gets the index of the device relative to other devices in the context
36+
size_t getDeviceIndex(ur_device_handle_t hDevice) {
37+
auto It = std::find(Devices.begin(), Devices.end(), hDevice);
38+
assert(It != Devices.end());
39+
return std::distance(Devices.begin(), It);
40+
}
41+
42+
bool containsDevice(ur_device_handle_t Device) {
43+
return std::find(Devices.begin(), Devices.end(), Device) != Devices.end();
44+
}
2845
};

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6767
LaunchArgs.GroupSize.z = GroupSize[2];
6868
LaunchArgs.DynSharedMemory = 0;
6969

70+
// Prepare memobj arguments
71+
for (auto &Arg : hKernel->Args.MemObjArgs) {
72+
Arg.Mem->enqueueMigrateMemoryToDeviceIfNeeded(hQueue->Device,
73+
hQueue->OffloadQueue);
74+
if (Arg.AccessFlags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
75+
Arg.Mem->setLastQueueWritingToMemObj(hQueue);
76+
}
77+
}
78+
7079
ol_event_handle_t EventOut;
7180
OL_RETURN_ON_ERR(
7281
olLaunchKernel(hQueue->OffloadQueue, hQueue->OffloadDevice,
@@ -105,8 +114,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
105114

106115
ol_event_handle_t EventOut = nullptr;
107116

108-
char *DevPtr =
109-
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
117+
// Note that this entry point may be called on a queue that may not be the
118+
// last queue to write to the MemBuffer, meaning we must perform the copy
119+
// from a different device
120+
// TODO: Evaluate whether this is better than just migrating the memory to the
121+
// correct device and then doing the read.
122+
if (hBuffer->LastQueueWritingToMemObj &&
123+
hBuffer->LastQueueWritingToMemObj->Device != hQueue->Device) {
124+
hQueue = hBuffer->LastQueueWritingToMemObj;
125+
}
126+
127+
char *DevPtr = reinterpret_cast<char *>(
128+
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->Device));
110129

111130
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice,
112131
DevPtr + offset, hQueue->OffloadDevice, size,
@@ -137,8 +156,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
137156

138157
ol_event_handle_t EventOut = nullptr;
139158

140-
char *DevPtr =
141-
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
159+
char *DevPtr = reinterpret_cast<char *>(
160+
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->Device));
142161

143162
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset,
144163
hQueue->OffloadDevice, pSrc, Adapter->HostDevice,

unified-runtime/source/adapters/offload/kernel.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
2929
return offloadResultToUR(Res);
3030
}
3131

32+
Kernel->Program = hProgram;
33+
3234
*phKernel = Kernel;
3335

3436
return UR_RESULT_SUCCESS;
@@ -99,7 +101,8 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
99101
: static_cast<ur_mem_flags_t>(UR_MEM_FLAG_READ_WRITE);
100102
hKernel->Args.addMemObjArg(argIndex, hArgValue, MemAccess);
101103

102-
auto Ptr = std::get<BufferMem>(hArgValue->Mem).Ptr;
104+
auto Ptr =
105+
std::get<BufferMem>(hArgValue->Mem).getPtr(hKernel->Program->Device);
103106
hKernel->Args.addArg(argIndex, sizeof(void *), &Ptr);
104107
return UR_RESULT_SUCCESS;
105108
}

unified-runtime/source/adapters/offload/kernel.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,6 @@ struct ur_kernel_handle_t_ : RefCounted {
7979

8080
ol_kernel_handle_t OffloadKernel;
8181
OffloadKernelArguments Args{};
82+
83+
ur_program_handle_t Program;
8284
};

unified-runtime/source/adapters/offload/memory.cpp

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,45 @@
1919
#include "memory.hpp"
2020
#include "ur2offload.hpp"
2121

22+
void *BufferMem::getPtr(ur_device_handle_t Device) const noexcept {
23+
// Create the allocation for this device if needed
24+
OuterMemStruct->prepareDeviceAllocation(Device);
25+
return Ptrs[OuterMemStruct->Context->getDeviceIndex(Device)];
26+
}
27+
28+
ur_result_t enqueueMigrateBufferToDevice(ur_mem_handle_t Mem,
29+
ur_device_handle_t Device,
30+
ol_queue_handle_t Queue) {
31+
auto &Buffer = std::get<BufferMem>(Mem->Mem);
32+
if (Mem->LastQueueWritingToMemObj == nullptr) {
33+
// Device allocation being initialized from host for the first time
34+
if (Buffer.HostPtr) {
35+
OL_RETURN_ON_ERR(
36+
olMemcpy(Queue, Buffer.getPtr(Device), Device->OffloadDevice,
37+
Buffer.HostPtr, Adapter->HostDevice, Buffer.Size, nullptr));
38+
}
39+
} else if (Mem->LastQueueWritingToMemObj->Device != Device) {
40+
auto LastDevice = Mem->LastQueueWritingToMemObj->Device;
41+
OL_RETURN_ON_ERR(olMemcpy(Queue, Buffer.getPtr(Device), Device->OffloadDevice,
42+
Buffer.getPtr(LastDevice), LastDevice->OffloadDevice,
43+
Buffer.Size, nullptr));
44+
}
45+
return UR_RESULT_SUCCESS;
46+
}
47+
48+
// TODO: Check lock in cuda adapter
49+
ur_result_t ur_mem_handle_t_::enqueueMigrateMemoryToDeviceIfNeeded(
50+
const ur_device_handle_t Device, ol_queue_handle_t Queue) {
51+
UR_ASSERT(Device, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
52+
// Device allocation has already been initialized with most up to date
53+
// data in buffer
54+
if (DeviceIsUpToDate[getContext()->getDeviceIndex(Device)]) {
55+
return UR_RESULT_SUCCESS;
56+
}
57+
58+
return enqueueMigrateBufferToDevice(this, Device, Queue);
59+
}
60+
2261
UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
2362
ur_context_handle_t hContext, ur_mem_flags_t flags, size_t size,
2463
const ur_buffer_properties_t *pProperties, ur_mem_handle_t *phBuffer) {
@@ -29,35 +68,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
2968
(flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) ||
3069
(flags & UR_MEM_FLAG_USE_HOST_POINTER);
3170

32-
void *Ptr = nullptr;
3371
auto HostPtr = pProperties ? pProperties->pHost : nullptr;
34-
auto OffloadDevice = hContext->Device->OffloadDevice;
3572
auto AllocMode = BufferMem::AllocMode::Default;
3673

3774
if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
38-
OL_RETURN_ON_ERR(
39-
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr));
75+
// Allocate on the first device, which will be valid on all devices in the
76+
// context
77+
OL_RETURN_ON_ERR(olMemAlloc(hContext->Devices[0]->OffloadDevice,
78+
OL_ALLOC_TYPE_HOST, size, &HostPtr));
4079

4180
// TODO: We (probably) need something like cuMemHostGetDevicePointer
4281
// for this to work everywhere. For now assume the managed host pointer is
4382
// device-accessible.
44-
Ptr = HostPtr;
4583
AllocMode = BufferMem::AllocMode::AllocHostPtr;
4684
} else {
47-
OL_RETURN_ON_ERR(
48-
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr));
4985
if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
5086
AllocMode = BufferMem::AllocMode::CopyIn;
5187
}
5288
}
5389

5490
ur_mem_handle_t ParentBuffer = nullptr;
5591
auto URMemObj = std::unique_ptr<ur_mem_handle_t_>(new ur_mem_handle_t_{
56-
hContext, ParentBuffer, flags, AllocMode, Ptr, HostPtr, size});
57-
58-
if (PerformInitialCopy) {
59-
OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr,
60-
Adapter->HostDevice, size, nullptr));
92+
hContext, ParentBuffer, flags, AllocMode, HostPtr, size});
93+
94+
if (PerformInitialCopy && HostPtr) {
95+
// Copy per device
96+
for (auto Device : hContext->Devices) {
97+
const auto &Ptr = std::get<BufferMem>(URMemObj->Mem).getPtr(Device);
98+
OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, Device->OffloadDevice, HostPtr,
99+
Adapter->HostDevice, size, nullptr));
100+
}
61101
}
62102

63103
*phBuffer = URMemObj.release();
@@ -79,7 +119,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
79119
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
80120
// TODO: Handle registered host memory
81121
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem);
82-
OL_RETURN_ON_ERR(olMemFree(BufferImpl.Ptr));
122+
for (auto *Ptr : BufferImpl.Ptrs) {
123+
if (Ptr) {
124+
OL_RETURN_ON_ERR(olMemFree(Ptr));
125+
}
126+
}
83127
}
84128

85129
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)