Skip to content

Commit 736967c

Browse files
committed
[UR][L0] FIx UR_PROGRAM_INFO_BINARIES query
1 parent 00ae3dd commit 736967c

File tree

2 files changed

+70
-43
lines changed

2 files changed

+70
-43
lines changed

unified-runtime/source/adapters/level_zero/program.cpp

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ ur_result_t urProgramCreateWithBinary(
136136
// information to distinguish the cases.
137137
try {
138138
for (uint32_t i = 0; i < numDevices; i++) {
139-
UR_ASSERT(ppBinaries[i] || !pLengths[0], UR_RESULT_ERROR_INVALID_VALUE);
139+
UR_ASSERT(ppBinaries[i] || !pLengths[i], UR_RESULT_ERROR_INVALID_VALUE);
140140
UR_ASSERT(hContext->isValidDevice(phDevices[i]),
141141
UR_RESULT_ERROR_INVALID_DEVICE);
142142
}
@@ -746,62 +746,40 @@ ur_result_t urProgramGetInfo(
746746
return ReturnValue(binarySizes.data(), binarySizes.size());
747747
}
748748
case UR_PROGRAM_INFO_BINARIES: {
749-
// The caller sets "ParamValue" to an array of pointers, one for each
750-
// device.
751-
uint8_t **PBinary = nullptr;
752-
if (ProgramInfo) {
753-
PBinary = ur_cast<uint8_t **>(ProgramInfo);
754-
if (!PBinary[0]) {
755-
break;
756-
}
757-
}
758749
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
759-
uint8_t *NativeBinaryPtr = nullptr;
760-
if (PBinary) {
761-
NativeBinaryPtr = PBinary[0];
750+
size_t NumDevices = Program->AssociatedDevices.size();
751+
if (PropSizeRet) {
752+
// Return the size of the array of pointers to binaries (for each device).
753+
*PropSizeRet = NumDevices * sizeof(uint8_t *);
762754
}
763755

764-
size_t SzBinary = 0;
765-
for (uint32_t deviceIndex = 0;
766-
deviceIndex < Program->AssociatedDevices.size(); deviceIndex++) {
756+
// If the caller did not provide an array of pointers to copy binaries into,
757+
// return early.
758+
if (!ProgramInfo)
759+
break;
760+
761+
// If the caller provided an array of pointers, copy the binaries.
762+
uint8_t **DestBinPtrs = ur_cast<uint8_t **>(ProgramInfo);
763+
for (uint32_t deviceIndex = 0; deviceIndex < NumDevices; deviceIndex++) {
764+
uint8_t *DestBinPtr = DestBinPtrs[deviceIndex];
765+
if (!DestBinPtr)
766+
continue;
767+
767768
auto ZeDevice = Program->AssociatedDevices[deviceIndex]->ZeDevice;
768769
auto State = Program->getState(ZeDevice);
769770
if (State == ur_program_handle_t_::Native) {
770771
// If Program was created from Native code then return that code.
771-
if (PBinary) {
772-
std::memcpy(PBinary[deviceIndex], Program->getCode(ZeDevice),
773-
Program->getCodeSize(ZeDevice));
774-
}
775-
SzBinary += Program->getCodeSize(ZeDevice);
776-
continue;
777-
}
778-
if (State == ur_program_handle_t_::IL ||
779-
State == ur_program_handle_t_::Object) {
780-
// We don't have a binary for this device, so don't update the output
781-
// pointer to the binary, only set return size to 0.
782-
if (PropSizeRet)
783-
*PropSizeRet = 0;
772+
std::memcpy(DestBinPtr, Program->getCode(ZeDevice),
773+
Program->getCodeSize(ZeDevice));
784774
} else if (State == ur_program_handle_t_::Exe) {
785775
auto ZeModule = Program->getZeModuleHandle(ZeDevice);
786776
if (!ZeModule) {
787777
return UR_RESULT_ERROR_INVALID_PROGRAM;
788778
}
789-
size_t binarySize = 0;
790-
if (PBinary) {
791-
NativeBinaryPtr = PBinary[deviceIndex];
792-
}
793-
// If the caller is using a Program which is a built binary, then
794-
// the program returned will either be a single module if this is a
795-
// native binary or the native binary for each device will be returned.
796-
ZE2UR_CALL(zeModuleGetNativeBinary,
797-
(ZeModule, &binarySize, NativeBinaryPtr));
798-
SzBinary += binarySize;
799-
} else {
800-
return UR_RESULT_ERROR_INVALID_PROGRAM;
779+
size_t DummySize;
780+
ZE2UR_CALL(zeModuleGetNativeBinary, (ZeModule, &DummySize, DestBinPtr));
801781
}
802782
}
803-
if (PropSizeRet)
804-
*PropSizeRet = SzBinary;
805783
break;
806784
}
807785
case UR_PROGRAM_INFO_NUM_KERNELS: {

unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithIL.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,52 @@ TEST_P(urMultiDeviceProgramTest, urMultiDeviceProgramGetInfo) {
6464
ASSERT_EQ(binaries[i].size(), 0);
6565
}
6666
}
67+
68+
// Build program for the second device only and check validity of the binary returned by urProgramGetInfo
69+
// by recreating program from the binary and building it.
70+
TEST_P(urMultiDeviceProgramTest, urMultiDeviceProgramGetInfoBinaries) {
71+
ur_backend_t backend;
72+
ASSERT_SUCCESS(urPlatformGetInfo(platform, UR_PLATFORM_INFO_BACKEND,
73+
sizeof(backend), &backend, nullptr));
74+
if (backend != UR_BACKEND_LEVEL_ZERO) {
75+
GTEST_SKIP();
76+
}
77+
std::vector<ur_device_handle_t> associated_devices(devices.size());
78+
ASSERT_SUCCESS(
79+
urProgramGetInfo(program, UR_PROGRAM_INFO_DEVICES,
80+
associated_devices.size() * sizeof(ur_device_handle_t),
81+
associated_devices.data(), nullptr));
82+
if (associated_devices.size() < 2) {
83+
GTEST_SKIP();
84+
}
85+
86+
// Build program for the second device only.
87+
auto subset = std::vector<ur_device_handle_t>(associated_devices.begin() + 1,
88+
associated_devices.begin() + 2);
89+
ASSERT_SUCCESS(
90+
urProgramBuildExp(program, subset.size(), subset.data(), nullptr));
91+
std::vector<size_t> binary_sizes(associated_devices.size());
92+
ASSERT_SUCCESS(urProgramGetInfo(program, UR_PROGRAM_INFO_BINARY_SIZES,
93+
binary_sizes.size() * sizeof(size_t),
94+
binary_sizes.data(), nullptr));
95+
std::vector<std::vector<uint8_t>> binaries(associated_devices.size());
96+
std::vector<const uint8_t *> pointers(associated_devices.size());
97+
for (size_t i = 0; i < associated_devices.size(); i++) {
98+
binaries[i].resize(binary_sizes[i]);
99+
pointers[i] = binaries[i].data();
100+
}
101+
102+
ASSERT_SUCCESS(urProgramGetInfo(program, UR_PROGRAM_INFO_BINARIES,
103+
sizeof(uint8_t *) * pointers.size(),
104+
pointers.data(), nullptr));
105+
106+
// Now create program from the obtained binary and build to check validity.
107+
ur_program_handle_t program_from_binary = nullptr;
108+
ASSERT_SUCCESS(urProgramCreateWithBinary(
109+
context, 1, associated_devices.data() + 1, binary_sizes.data() + 1,
110+
pointers.data() + 1, nullptr, &program_from_binary));
111+
ASSERT_NE(program_from_binary, nullptr);
112+
ASSERT_SUCCESS(urProgramBuildExp(program_from_binary, 1,
113+
associated_devices.data() + 1, nullptr));
114+
ASSERT_SUCCESS(urProgramRelease(program_from_binary));
115+
}

0 commit comments

Comments
 (0)