Skip to content

[UR][Offload] Initial support for multi-device contexts #19369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions unified-runtime/source/adapters/offload/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *, ur_context_handle_t *phContext) {
if (DeviceCount > 1) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

// For multi-device contexts, all devices must have the same platform.
ur_device_handle_t FirstDevice = *phDevices;
for (uint32_t i = 1; i < DeviceCount; i++) {
if (phDevices[i]->Platform != FirstDevice->Platform) {
return UR_RESULT_ERROR_INVALID_DEVICE;
}
}

auto Ctx = new ur_context_handle_t_(*phDevices);
auto Ctx = new ur_context_handle_t_(phDevices, DeviceCount);
*phContext = Ctx;
return UR_RESULT_SUCCESS;
}
Expand All @@ -30,9 +35,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

switch (propName) {
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(uint32_t{1});
return ReturnValue(hContext->Devices.size());
case UR_CONTEXT_INFO_DEVICES:
return ReturnValue(&hContext->Device, 1);
return ReturnValue(hContext->Devices.data(), hContext->Devices.size());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->RefCount.load());
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
Expand Down
27 changes: 22 additions & 5 deletions unified-runtime/source/adapters/offload/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,28 @@
#include <ur_api.h>

struct ur_context_handle_t_ : RefCounted {
ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
urDeviceRetain(Device);
ur_context_handle_t_(const ur_device_handle_t *Devs, size_t NumDevices)
: Devices{Devs, Devs + NumDevices} {
for (auto Device : Devices) {
urDeviceRetain(Device);
}
}
~ur_context_handle_t_() { urDeviceRelease(Device); }
~ur_context_handle_t_() {
for (auto Device : Devices) {
urDeviceRelease(Device);
}
}

std::vector<ur_device_handle_t> Devices;

ur_device_handle_t Device;
std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
// Returns the position of hDevice within this context's device list.
// The device is expected to belong to the context; this is only checked by
// an assert, so when asserts are compiled out and the device is absent the
// returned value equals Devices.size().
size_t getDeviceIndex(ur_device_handle_t hDevice) {
  const auto First = Devices.begin();
  const auto Last = Devices.end();
  const auto Pos = std::find(First, Last, hDevice);
  assert(Pos != Last);
  return std::distance(First, Pos);
}

// Reports whether Device is one of the devices held by this context.
bool containsDevice(ur_device_handle_t Device) {
  for (auto Candidate : Devices) {
    if (Candidate == Device) {
      return true;
    }
  }
  return false;
}
};
27 changes: 23 additions & 4 deletions unified-runtime/source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
LaunchArgs.GroupSize.z = GroupSize[2];
LaunchArgs.DynSharedMemory = 0;

// Prepare memobj arguments
for (auto &Arg : hKernel->Args.MemObjArgs) {
Arg.Mem->enqueueMigrateMemoryToDeviceIfNeeded(hQueue->Device,
hQueue->OffloadQueue);
if (Arg.AccessFlags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
Arg.Mem->setLastQueueWritingToMemObj(hQueue);
}
}

ol_event_handle_t EventOut;
OL_RETURN_ON_ERR(
olLaunchKernel(hQueue->OffloadQueue, hQueue->OffloadDevice,
Expand Down Expand Up @@ -105,8 +114,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(

ol_event_handle_t EventOut = nullptr;

char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
// Note that this entry point may be called on a queue that may not be the
// last queue to write to the MemBuffer, meaning we must perform the copy
// from a different device
// TODO: Evaluate whether this is better than just migrating the memory to the
// correct device and then doing the read.
if (hBuffer->LastQueueWritingToMemObj &&
hBuffer->LastQueueWritingToMemObj->Device != hQueue->Device) {
hQueue = hBuffer->LastQueueWritingToMemObj;
}

char *DevPtr = reinterpret_cast<char *>(
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->Device));

OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice,
DevPtr + offset, hQueue->OffloadDevice, size,
Expand Down Expand Up @@ -137,8 +156,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(

ol_event_handle_t EventOut = nullptr;

char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
char *DevPtr = reinterpret_cast<char *>(
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->Device));

OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset,
hQueue->OffloadDevice, pSrc, Adapter->HostDevice,
Expand Down
5 changes: 4 additions & 1 deletion unified-runtime/source/adapters/offload/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
return offloadResultToUR(Res);
}

Kernel->Program = hProgram;

*phKernel = Kernel;

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -99,7 +101,8 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
: static_cast<ur_mem_flags_t>(UR_MEM_FLAG_READ_WRITE);
hKernel->Args.addMemObjArg(argIndex, hArgValue, MemAccess);

auto Ptr = std::get<BufferMem>(hArgValue->Mem).Ptr;
auto Ptr =
std::get<BufferMem>(hArgValue->Mem).getPtr(hKernel->Program->Device);
hKernel->Args.addArg(argIndex, sizeof(void *), &Ptr);
return UR_RESULT_SUCCESS;
}
Expand Down
2 changes: 2 additions & 0 deletions unified-runtime/source/adapters/offload/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,6 @@ struct ur_kernel_handle_t_ : RefCounted {

ol_kernel_handle_t OffloadKernel;
OffloadKernelArguments Args{};

ur_program_handle_t Program;
};
70 changes: 57 additions & 13 deletions unified-runtime/source/adapters/offload/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,45 @@
#include "memory.hpp"
#include "ur2offload.hpp"

// Returns the pointer stored for Device in this buffer's per-device pointer
// list, asking the owning memory object to create that device's allocation
// first if it does not exist yet.
void *BufferMem::getPtr(ur_device_handle_t Device) const noexcept {
  // NOTE(review): any result returned by prepareDeviceAllocation is discarded
  // here — confirm that an allocation failure cannot go unnoticed.
  OuterMemStruct->prepareDeviceAllocation(Device);

  // Pointers are indexed by the device's position within the context.
  const auto DeviceIdx = OuterMemStruct->Context->getDeviceIndex(Device);
  return Ptrs[DeviceIdx];
}

// Enqueues on Queue the copy needed to bring the buffer's contents to Device.
//
// - If no queue has written to the buffer yet, the device allocation is filled
//   from the buffer's host pointer (when one is set); otherwise nothing is
//   copied.
// - If the last writing queue belongs to a different device, the data is
//   copied from that device's allocation.
// - If the last writing queue belongs to Device itself, nothing is enqueued.
//
// Returns UR_RESULT_SUCCESS, or leaves early through OL_RETURN_ON_ERR when
// olMemcpy reports an error.
ur_result_t enqueueMigrateBufferToDevice(ur_mem_handle_t Mem,
                                         ur_device_handle_t Device,
                                         ol_queue_handle_t Queue) {
  auto &Buffer = std::get<BufferMem>(Mem->Mem);
  auto LastWriter = Mem->LastQueueWritingToMemObj;

  if (LastWriter == nullptr) {
    // Device allocation being initialized from host for the first time
    if (!Buffer.HostPtr) {
      return UR_RESULT_SUCCESS;
    }
    OL_RETURN_ON_ERR(olMemcpy(Queue, Buffer.getPtr(Device),
                              Device->OffloadDevice, Buffer.HostPtr,
                              Adapter->HostDevice, Buffer.Size, nullptr));
    return UR_RESULT_SUCCESS;
  }

  auto LastDevice = LastWriter->Device;
  if (LastDevice == Device) {
    // The most recent write already happened on the requested device.
    return UR_RESULT_SUCCESS;
  }

  // Device-to-device copy from the allocation that was written last.
  OL_RETURN_ON_ERR(olMemcpy(Queue, Buffer.getPtr(Device),
                            Device->OffloadDevice, Buffer.getPtr(LastDevice),
                            LastDevice->OffloadDevice, Buffer.Size, nullptr));
  return UR_RESULT_SUCCESS;
}

// TODO: Check lock in cuda adapter
//
// Enqueues a migration of this memory object to Device on Queue unless the
// device is already flagged as holding the most up to date data.
//
// Returns UR_RESULT_ERROR_INVALID_NULL_HANDLE when Device is null, otherwise
// UR_RESULT_SUCCESS or the result of enqueueMigrateBufferToDevice.
//
// NOTE(review): nothing in this function marks Device as up to date after a
// migration has been enqueued — confirm that DeviceIsUpToDate is maintained
// elsewhere.
ur_result_t ur_mem_handle_t_::enqueueMigrateMemoryToDeviceIfNeeded(
    const ur_device_handle_t Device, ol_queue_handle_t Queue) {
  UR_ASSERT(Device, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

  const auto DeviceIdx = getContext()->getDeviceIndex(Device);
  const bool NeedsMigration = !DeviceIsUpToDate[DeviceIdx];
  if (NeedsMigration) {
    return enqueueMigrateBufferToDevice(this, Device, Queue);
  }

  // Device allocation has already been initialized with most up to date
  // data in buffer
  return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
ur_context_handle_t hContext, ur_mem_flags_t flags, size_t size,
const ur_buffer_properties_t *pProperties, ur_mem_handle_t *phBuffer) {
Expand All @@ -29,35 +68,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
(flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) ||
(flags & UR_MEM_FLAG_USE_HOST_POINTER);

void *Ptr = nullptr;
auto HostPtr = pProperties ? pProperties->pHost : nullptr;
auto OffloadDevice = hContext->Device->OffloadDevice;
auto AllocMode = BufferMem::AllocMode::Default;

if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
OL_RETURN_ON_ERR(
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr));
// Allocate on the first device, which will be valid on all devices in the
// context
OL_RETURN_ON_ERR(olMemAlloc(hContext->Devices[0]->OffloadDevice,
OL_ALLOC_TYPE_HOST, size, &HostPtr));

// TODO: We (probably) need something like cuMemHostGetDevicePointer
// for this to work everywhere. For now assume the managed host pointer is
// device-accessible.
Ptr = HostPtr;
AllocMode = BufferMem::AllocMode::AllocHostPtr;
} else {
OL_RETURN_ON_ERR(
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr));
if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
AllocMode = BufferMem::AllocMode::CopyIn;
}
}

ur_mem_handle_t ParentBuffer = nullptr;
auto URMemObj = std::unique_ptr<ur_mem_handle_t_>(new ur_mem_handle_t_{
hContext, ParentBuffer, flags, AllocMode, Ptr, HostPtr, size});

if (PerformInitialCopy) {
OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr,
Adapter->HostDevice, size, nullptr));
hContext, ParentBuffer, flags, AllocMode, HostPtr, size});

if (PerformInitialCopy && HostPtr) {
// Copy per device
for (auto Device : hContext->Devices) {
const auto &Ptr = std::get<BufferMem>(URMemObj->Mem).getPtr(Device);
OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, Device->OffloadDevice, HostPtr,
Adapter->HostDevice, size, nullptr));
}
}

*phBuffer = URMemObj.release();
Expand All @@ -79,7 +119,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
// TODO: Handle registered host memory
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem);
OL_RETURN_ON_ERR(olMemFree(BufferImpl.Ptr));
for (auto *Ptr : BufferImpl.Ptrs) {
if (Ptr) {
OL_RETURN_ON_ERR(olMemFree(Ptr));
}
}
}

return UR_RESULT_SUCCESS;
Expand Down
Loading
Loading