diff --git a/unified-runtime/source/adapters/offload/context.cpp b/unified-runtime/source/adapters/offload/context.cpp index 2dcbcd4da82f5..9b8fe57347141 100644 --- a/unified-runtime/source/adapters/offload/context.cpp +++ b/unified-runtime/source/adapters/offload/context.cpp @@ -14,11 +14,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( uint32_t DeviceCount, const ur_device_handle_t *phDevices, const ur_context_properties_t *, ur_context_handle_t *phContext) { - if (DeviceCount > 1) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + + // For multi-device contexts, all devices must have the same platform. + ur_device_handle_t FirstDevice = *phDevices; + for (uint32_t i = 1; i < DeviceCount; i++) { + if (phDevices[i]->Platform != FirstDevice->Platform) { + return UR_RESULT_ERROR_INVALID_DEVICE; + } } - auto Ctx = new ur_context_handle_t_(*phDevices); + auto Ctx = new ur_context_handle_t_(phDevices, DeviceCount); *phContext = Ctx; return UR_RESULT_SUCCESS; } @@ -30,9 +35,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, switch (propName) { case UR_CONTEXT_INFO_NUM_DEVICES: - return ReturnValue(uint32_t{1}); + return ReturnValue(hContext->Devices.size()); case UR_CONTEXT_INFO_DEVICES: - return ReturnValue(&hContext->Device, 1); + return ReturnValue(hContext->Devices.data(), hContext->Devices.size()); case UR_CONTEXT_INFO_REFERENCE_COUNT: return ReturnValue(hContext->RefCount.load()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: diff --git a/unified-runtime/source/adapters/offload/context.hpp b/unified-runtime/source/adapters/offload/context.hpp index 38857446c47f8..9ad64865f5b29 100644 --- a/unified-runtime/source/adapters/offload/context.hpp +++ b/unified-runtime/source/adapters/offload/context.hpp @@ -18,11 +18,28 @@ #include struct ur_context_handle_t_ : RefCounted { - ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} { - urDeviceRetain(Device); + ur_context_handle_t_(const ur_device_handle_t *Devs, size_t 
NumDevices) + : Devices{Devs, Devs + NumDevices} { + for (auto Device : Devices) { + urDeviceRetain(Device); + } } - ~ur_context_handle_t_() { urDeviceRelease(Device); } + ~ur_context_handle_t_() { + for (auto Device : Devices) { + urDeviceRelease(Device); + } + } + + std::vector<ur_device_handle_t> Devices; - ur_device_handle_t Device; - std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap; + // Gets the index of the device relative to other devices in the context + size_t getDeviceIndex(ur_device_handle_t hDevice) { + auto It = std::find(Devices.begin(), Devices.end(), hDevice); + assert(It != Devices.end()); + return std::distance(Devices.begin(), It); + } + + bool containsDevice(ur_device_handle_t Device) { + return std::find(Devices.begin(), Devices.end(), Device) != Devices.end(); + } }; diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp index 9b5cd9140a0f1..b076d6d3d9c8c 100644 --- a/unified-runtime/source/adapters/offload/enqueue.cpp +++ b/unified-runtime/source/adapters/offload/enqueue.cpp @@ -67,6 +67,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( LaunchArgs.GroupSize.z = GroupSize[2]; LaunchArgs.DynSharedMemory = 0; + // Prepare memobj arguments + for (auto &Arg : hKernel->Args.MemObjArgs) { + Arg.Mem->enqueueMigrateMemoryToDeviceIfNeeded(hQueue->Device, + hQueue->OffloadQueue); + if (Arg.AccessFlags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) { + Arg.Mem->setLastQueueWritingToMemObj(hQueue); + } + } + ol_event_handle_t EventOut; OL_RETURN_ON_ERR( olLaunchKernel(hQueue->OffloadQueue, hQueue->OffloadDevice, @@ -105,8 +114,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( ol_event_handle_t EventOut = nullptr; - char *DevPtr = - reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr); + // Note that this entry point may be called on a queue that may not be the + // last queue to write to the MemBuffer, meaning we must perform the copy + // from a different device + // TODO: Evaluate whether this is 
better than just migrating the memory to the + // correct device and then doing the read. + if (hBuffer->LastQueueWritingToMemObj && + hBuffer->LastQueueWritingToMemObj->Device != hQueue->Device) { + hQueue = hBuffer->LastQueueWritingToMemObj; + } + + char *DevPtr = reinterpret_cast( + std::get(hBuffer->Mem).getPtr(hQueue->Device)); OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice, DevPtr + offset, hQueue->OffloadDevice, size, @@ -137,8 +156,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( ol_event_handle_t EventOut = nullptr; - char *DevPtr = - reinterpret_cast(std::get(hBuffer->Mem).Ptr); + char *DevPtr = reinterpret_cast( + std::get(hBuffer->Mem).getPtr(hQueue->Device)); OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset, hQueue->OffloadDevice, pSrc, Adapter->HostDevice, diff --git a/unified-runtime/source/adapters/offload/kernel.cpp b/unified-runtime/source/adapters/offload/kernel.cpp index b9e9152d437a2..8a3bc946ad019 100644 --- a/unified-runtime/source/adapters/offload/kernel.cpp +++ b/unified-runtime/source/adapters/offload/kernel.cpp @@ -29,6 +29,8 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName, return offloadResultToUR(Res); } + Kernel->Program = hProgram; + *phKernel = Kernel; return UR_RESULT_SUCCESS; @@ -99,7 +101,8 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex, : static_cast(UR_MEM_FLAG_READ_WRITE); hKernel->Args.addMemObjArg(argIndex, hArgValue, MemAccess); - auto Ptr = std::get(hArgValue->Mem).Ptr; + auto Ptr = + std::get(hArgValue->Mem).getPtr(hKernel->Program->Device); hKernel->Args.addArg(argIndex, sizeof(void *), &Ptr); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/offload/kernel.hpp b/unified-runtime/source/adapters/offload/kernel.hpp index e8ff732d700f0..8aa55fc33cdbe 100644 --- a/unified-runtime/source/adapters/offload/kernel.hpp +++ b/unified-runtime/source/adapters/offload/kernel.hpp @@ -79,4 +79,6 @@ struct 
ur_kernel_handle_t_ : RefCounted { ol_kernel_handle_t OffloadKernel; OffloadKernelArguments Args{}; + + ur_program_handle_t Program; }; diff --git a/unified-runtime/source/adapters/offload/memory.cpp b/unified-runtime/source/adapters/offload/memory.cpp index 564e616a973ab..b16074b9ebbf5 100644 --- a/unified-runtime/source/adapters/offload/memory.cpp +++ b/unified-runtime/source/adapters/offload/memory.cpp @@ -19,6 +19,45 @@ #include "memory.hpp" #include "ur2offload.hpp" +void *BufferMem::getPtr(ur_device_handle_t Device) const noexcept { + // Create the allocation for this device if needed + OuterMemStruct->prepareDeviceAllocation(Device); + return Ptrs[OuterMemStruct->Context->getDeviceIndex(Device)]; +} + +ur_result_t enqueueMigrateBufferToDevice(ur_mem_handle_t Mem, + ur_device_handle_t Device, + ol_queue_handle_t Queue) { + auto &Buffer = std::get(Mem->Mem); + if (Mem->LastQueueWritingToMemObj == nullptr) { + // Device allocation being initialized from host for the first time + if (Buffer.HostPtr) { + OL_RETURN_ON_ERR(olMemcpy(Queue, Buffer.getPtr(Device), + Device->OffloadDevice, Buffer.HostPtr, + Adapter->HostDevice, Buffer.Size, nullptr)); + } + } else if (Mem->LastQueueWritingToMemObj->Device != Device) { + auto LastDevice = Mem->LastQueueWritingToMemObj->Device; + OL_RETURN_ON_ERR(olMemcpy(Queue, Buffer.getPtr(Device), + Device->OffloadDevice, Buffer.getPtr(LastDevice), + LastDevice->OffloadDevice, Buffer.Size, nullptr)); + } + return UR_RESULT_SUCCESS; +} + +// TODO: Check lock in cuda adapter +ur_result_t ur_mem_handle_t_::enqueueMigrateMemoryToDeviceIfNeeded( + const ur_device_handle_t Device, ol_queue_handle_t Queue) { + UR_ASSERT(Device, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + // Device allocation has already been initialized with most up to date + // data in buffer + if (DeviceIsUpToDate[getContext()->getDeviceIndex(Device)]) { + return UR_RESULT_SUCCESS; + } + + return enqueueMigrateBufferToDevice(this, Device, Queue); +} + UR_APIEXPORT ur_result_t 
UR_APICALL urMemBufferCreate( ur_context_handle_t hContext, ur_mem_flags_t flags, size_t size, const ur_buffer_properties_t *pProperties, ur_mem_handle_t *phBuffer) { @@ -29,23 +68,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) || (flags & UR_MEM_FLAG_USE_HOST_POINTER); - void *Ptr = nullptr; auto HostPtr = pProperties ? pProperties->pHost : nullptr; - auto OffloadDevice = hContext->Device->OffloadDevice; auto AllocMode = BufferMem::AllocMode::Default; if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) { - OL_RETURN_ON_ERR( - olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr)); + // Allocate on the first device, which will be valid on all devices in the + // context + OL_RETURN_ON_ERR(olMemAlloc(hContext->Devices[0]->OffloadDevice, + OL_ALLOC_TYPE_HOST, size, &HostPtr)); // TODO: We (probably) need something like cuMemHostGetDevicePointer // for this to work everywhere. For now assume the managed host pointer is // device-accessible. 
- Ptr = HostPtr; AllocMode = BufferMem::AllocMode::AllocHostPtr; } else { - OL_RETURN_ON_ERR( - olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr)); if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) { AllocMode = BufferMem::AllocMode::CopyIn; } @@ -53,11 +89,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( ur_mem_handle_t ParentBuffer = nullptr; auto URMemObj = std::unique_ptr(new ur_mem_handle_t_{ - hContext, ParentBuffer, flags, AllocMode, Ptr, HostPtr, size}); - - if (PerformInitialCopy) { - OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr, - Adapter->HostDevice, size, nullptr)); + hContext, ParentBuffer, flags, AllocMode, HostPtr, size}); + + if (PerformInitialCopy && HostPtr) { + // Copy per device + for (auto Device : hContext->Devices) { + const auto &Ptr = std::get(URMemObj->Mem).getPtr(Device); + OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, Device->OffloadDevice, HostPtr, + Adapter->HostDevice, size, nullptr)); + } } *phBuffer = URMemObj.release(); @@ -79,7 +119,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) { // TODO: Handle registered host memory auto &BufferImpl = std::get(MemObjPtr->Mem); - OL_RETURN_ON_ERR(olMemFree(BufferImpl.Ptr)); + for (auto *Ptr : BufferImpl.Ptrs) { + if (Ptr) { + OL_RETURN_ON_ERR(olMemFree(Ptr)); + } + } } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/offload/memory.hpp b/unified-runtime/source/adapters/offload/memory.hpp index 8e1666aa36924..e507c88f7cd3b 100644 --- a/unified-runtime/source/adapters/offload/memory.hpp +++ b/unified-runtime/source/adapters/offload/memory.hpp @@ -13,8 +13,27 @@ #include "ur_api.h" #include "common.hpp" +#include "context.hpp" +#include "queue.hpp" +#include "ur2offload.hpp" struct BufferMem { + + enum class BufferStrategy { + // When pinned host memory isn't used, make a separate allocation on each + // device, and migrate data from the last used device to the 
current device + // when a buffer is used in an enqueue command. + DiscreteDeviceAllocs, + // When pinned host memory isn't used, make a single shared USM allocation + // that is visible on all devices. When a buffer is used in an enqueue + // command, prefetch the data to the active device. + // TODO: Implement this. + SingleSharedAlloc + }; + + // This is the most conventional implementation of buffers. + static constexpr auto Strategy = BufferStrategy::DiscreteDeviceAllocs; + enum class AllocMode { Default, UseHostPtr, @@ -40,8 +59,12 @@ struct BufferMem { }; ur_mem_handle_t Parent; - // Underlying device pointer - void *Ptr; + // Outer UR mem holding this BufferMem in variant + ur_mem_handle_t OuterMemStruct; + + // Underlying device pointers + std::vector Ptrs; + // Pointer associated with this device on the host void *HostPtr; size_t Size; @@ -49,12 +72,15 @@ struct BufferMem { AllocMode MemAllocMode; std::unordered_map PtrToBufferMap; - BufferMem(ur_mem_handle_t Parent, BufferMem::AllocMode Mode, void *Ptr, + BufferMem(ur_mem_handle_t Parent, ur_mem_handle_t OuterMemStruct, + ur_context_handle_t Context, BufferMem::AllocMode Mode, void *HostPtr, size_t Size) - : Parent{Parent}, Ptr{Ptr}, HostPtr{HostPtr}, Size{Size}, + : Parent{Parent}, OuterMemStruct{OuterMemStruct}, + Ptrs(Context->Devices.size(), nullptr), HostPtr{HostPtr}, Size{Size}, MemAllocMode{Mode} {}; - void *get() const noexcept { return Ptr; } + void *getPtr(ur_device_handle_t Device) const noexcept; + size_t getSize() const noexcept { return Size; } BufferMap *getMapDetails(void *Map) { @@ -95,19 +121,93 @@ struct ur_mem_handle_t_ : RefCounted { enum class Type { Buffer } MemType; ur_mem_flags_t MemFlags; + ur_mutex MemoryAllocationMutex; + // For now we only support BufferMem. Eventually we'll support images, so use // a variant to store the underlying object. std::variant Mem; + // For every device in the context, is it known to have the latest copy of the + // data. 
Operations modifying the buffer should invalidate it on all devices + but the one the operation occurred on. + std::vector<bool> DeviceIsUpToDate; + + ur_queue_handle_t LastQueueWritingToMemObj{nullptr}; + ur_mem_handle_t_(ur_context_handle_t Context, ur_mem_handle_t Parent, ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode, - void *Ptr, void *HostPtr, size_t Size) + void *HostPtr, size_t Size) : Context{Context}, MemType{Type::Buffer}, MemFlags{MemFlags}, - Mem{BufferMem{Parent, Mode, Ptr, HostPtr, Size}} { + Mem{BufferMem{Parent, this, Context, Mode, HostPtr, Size}}, + DeviceIsUpToDate(Context->Devices.size(), false) { urContextRetain(Context); }; ~ur_mem_handle_t_() { urContextRelease(Context); } ur_context_handle_t getContext() const noexcept { return Context; } + + void *getDevicePointer(const ur_device_handle_t Device) { + auto DeviceIdx = Context->getDeviceIndex(Device); + auto &Buffer = std::get<BufferMem>(Mem); + + if (auto *Ptr = Buffer.Ptrs[DeviceIdx]) { + return Ptr; + } else { + olMemAlloc(Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE, Buffer.Size, + &Buffer.Ptrs[DeviceIdx]); + return Buffer.Ptrs[DeviceIdx]; + } + } + + ur_result_t prepareDeviceAllocation(ur_device_handle_t Device) { + // Lock to prevent duplicate allocations in race conditions + ur_lock LockGuard(MemoryAllocationMutex); + + auto DeviceIdx = Context->getDeviceIndex(Device); + auto &Buffer = std::get<BufferMem>(Mem); + auto DevPtr = Buffer.Ptrs[DeviceIdx]; + + // Allocation has already been made + if (DevPtr != nullptr) { + return UR_RESULT_SUCCESS; + } + + if (Buffer.MemAllocMode == BufferMem::AllocMode::AllocHostPtr) { + // Host allocation has already been made by this point. + // TODO: We (probably) need something like cuMemHostGetDevicePointer + // for this to work everywhere. For now assume the managed host pointer is + // always device-accessible. 
+ DevPtr = Buffer.HostPtr; + + } else if (Buffer.MemAllocMode == BufferMem::AllocMode::UseHostPtr) { + // TODO: This code path is never used (same as the cuda adapter) + DevPtr = Buffer.HostPtr; + } else { + auto Res = olMemAlloc(Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE, + Buffer.Size, &DevPtr); + if (Res) { + return offloadResultToUR(Res); + } + } + + Buffer.Ptrs[DeviceIdx] = DevPtr; + return UR_RESULT_SUCCESS; + } + + void setLastQueueWritingToMemObj(ur_queue_handle_t WritingQueue) { + urQueueRetain(WritingQueue); + if (LastQueueWritingToMemObj != nullptr) { + urQueueRelease(LastQueueWritingToMemObj); + } + LastQueueWritingToMemObj = WritingQueue; + for (const auto &Device : Context->Devices) { + DeviceIsUpToDate[Context->getDeviceIndex(Device)] = + Device == WritingQueue->Device; + } + } + + ur_result_t + enqueueMigrateMemoryToDeviceIfNeeded(const ur_device_handle_t Device, + ol_queue_handle_t Queue); }; diff --git a/unified-runtime/source/adapters/offload/program.cpp b/unified-runtime/source/adapters/offload/program.cpp index dde21a20b2c24..289edd4b42d46 100644 --- a/unified-runtime/source/adapters/offload/program.cpp +++ b/unified-runtime/source/adapters/offload/program.cpp @@ -27,7 +27,7 @@ namespace { // Workaround for Offload not supporting PTX binaries. Force CUDA programs // to be linked so they end up as CUBIN. 
#ifdef UR_CUDA_ENABLED -ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext, +ur_result_t ProgramCreateCudaWorkaround(ur_device_handle_t hDevice, const uint8_t *Binary, size_t Length, ur_program_handle_t *phProgram) { uint8_t *RealBinary; @@ -49,8 +49,8 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext, #endif ur_program_handle_t Program = new ur_program_handle_t_(); - auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary, - RealLength, &Program->OffloadProgram); + auto Res = olCreateProgram(hDevice->OffloadDevice, RealBinary, RealLength, + &Program->OffloadProgram); // Program owns the linked module now cuLinkDestroy(State); @@ -74,8 +74,8 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *, } // namespace UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( - ur_context_handle_t hContext, uint32_t numDevices, - ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries, + ur_context_handle_t, uint32_t numDevices, ur_device_handle_t *phDevices, + size_t *pLengths, const uint8_t **ppBinaries, const ur_program_properties_t *, ur_program_handle_t *phProgram) { if (numDevices > 1) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; @@ -104,12 +104,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( olGetPlatformInfo(phDevices[0]->Platform->OffloadPlatform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend), &Backend); if (Backend == OL_PLATFORM_BACKEND_CUDA) { - return ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength, + return ProgramCreateCudaWorkaround(phDevices[0], RealBinary, RealLength, phProgram); } ur_program_handle_t Program = new ur_program_handle_t_(); - auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary, + auto Res = olCreateProgram(phDevices[0]->OffloadDevice, RealBinary, RealLength, &Program->OffloadProgram); if (Res != OL_SUCCESS) { @@ -117,6 +117,9 @@ UR_APIEXPORT ur_result_t UR_APICALL 
urProgramCreateWithBinary( return offloadResultToUR(Res); } + // We only have one device + Program->Device = phDevices[0]; + *phProgram = Program; return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/offload/program.hpp b/unified-runtime/source/adapters/offload/program.hpp index 1d0263aad2998..0ae1dab97ae26 100644 --- a/unified-runtime/source/adapters/offload/program.hpp +++ b/unified-runtime/source/adapters/offload/program.hpp @@ -17,4 +17,5 @@ struct ur_program_handle_t_ : RefCounted { ol_program_handle_t OffloadProgram; + ur_device_handle_t Device; }; diff --git a/unified-runtime/source/adapters/offload/queue.cpp b/unified-runtime/source/adapters/offload/queue.cpp index 57a10fafa05b6..20e3167a392e5 100644 --- a/unified-runtime/source/adapters/offload/queue.cpp +++ b/unified-runtime/source/adapters/offload/queue.cpp @@ -21,7 +21,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate( [[maybe_unused]] ur_context_handle_t hContext, ur_device_handle_t hDevice, const ur_queue_properties_t *, ur_queue_handle_t *phQueue) { - assert(hContext->Device == hDevice); + if (!hContext->containsDevice(hDevice)) { + return UR_RESULT_ERROR_INVALID_DEVICE; + } ur_queue_handle_t Queue = new ur_queue_handle_t_(); auto Res = olCreateQueue(hDevice->OffloadDevice, &Queue->OffloadQueue); @@ -31,6 +33,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate( } Queue->OffloadDevice = hDevice->OffloadDevice; + Queue->Device = hDevice; *phQueue = Queue; diff --git a/unified-runtime/source/adapters/offload/queue.hpp b/unified-runtime/source/adapters/offload/queue.hpp index 6afe4bf15098e..6c9e3f51d22c3 100644 --- a/unified-runtime/source/adapters/offload/queue.hpp +++ b/unified-runtime/source/adapters/offload/queue.hpp @@ -18,4 +18,5 @@ struct ur_queue_handle_t_ : RefCounted { ol_queue_handle_t OffloadQueue; ol_device_handle_t OffloadDevice; + ur_device_handle_t Device; }; diff --git a/unified-runtime/source/adapters/offload/usm.cpp 
b/unified-runtime/source/adapters/offload/usm.cpp index 99f7931e9ddd7..c10e80f228e93 100644 --- a/unified-runtime/source/adapters/offload/usm.cpp +++ b/unified-runtime/source/adapters/offload/usm.cpp @@ -20,30 +20,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *, ur_usm_pool_handle_t, size_t size, void **ppMem) { - OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, + // Pick any device to do the host alloc, the allocation will be accessible on + // any device in the platform. + OL_RETURN_ON_ERR(olMemAlloc(hContext->Devices[0]->OffloadDevice, OL_ALLOC_TYPE_HOST, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( - ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *, + ur_context_handle_t, ur_device_handle_t Device, const ur_usm_desc_t *, ur_usm_pool_handle_t, size_t size, void **ppMem) { - OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, - OL_ALLOC_TYPE_DEVICE, size, ppMem)); + OL_RETURN_ON_ERR( + olMemAlloc(Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( - ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *, + ur_context_handle_t, ur_device_handle_t Device, const ur_usm_desc_t *, ur_usm_pool_handle_t, size_t size, void **ppMem) { - OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, - OL_ALLOC_TYPE_MANAGED, size, ppMem)); + OL_RETURN_ON_ERR( + olMemAlloc(Device->OffloadDevice, OL_ALLOC_TYPE_MANAGED, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED); return UR_RESULT_SUCCESS; }