@@ -136,7 +136,7 @@ ur_result_t urProgramCreateWithBinary(
136136 // information to distinguish the cases.
137137 try {
138138 for (uint32_t i = 0 ; i < numDevices; i++) {
139- UR_ASSERT (ppBinaries[i] || !pLengths[0 ], UR_RESULT_ERROR_INVALID_VALUE);
139+ UR_ASSERT (ppBinaries[i] || !pLengths[i ], UR_RESULT_ERROR_INVALID_VALUE);
140140 UR_ASSERT (hContext->isValidDevice (phDevices[i]),
141141 UR_RESULT_ERROR_INVALID_DEVICE);
142142 }
@@ -746,62 +746,40 @@ ur_result_t urProgramGetInfo(
746746 return ReturnValue (binarySizes.data (), binarySizes.size ());
747747 }
748748 case UR_PROGRAM_INFO_BINARIES: {
749- // The caller sets "ParamValue" to an array of pointers, one for each
750- // device.
751- uint8_t **PBinary = nullptr ;
752- if (ProgramInfo) {
753- PBinary = ur_cast<uint8_t **>(ProgramInfo);
754- if (!PBinary[0 ]) {
755- break ;
756- }
757- }
758749 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
759- uint8_t *NativeBinaryPtr = nullptr ;
760- if (PBinary) {
761- NativeBinaryPtr = PBinary[0 ];
750+ size_t NumDevices = Program->AssociatedDevices .size ();
751+ if (PropSizeRet) {
752+ // Return the size of the array of pointers to binaries (for each device).
753+ *PropSizeRet = NumDevices * sizeof (uint8_t *);
762754 }
763755
764- size_t SzBinary = 0 ;
765- for (uint32_t deviceIndex = 0 ;
766- deviceIndex < Program->AssociatedDevices .size (); deviceIndex++) {
756+ // If the caller did not provide an array of pointers to copy binaries into,
757+ // return early.
758+ if (!ProgramInfo)
759+ break ;
760+
761+ // If the caller provided an array of pointers, copy the binaries.
762+ uint8_t **DestBinPtrs = ur_cast<uint8_t **>(ProgramInfo);
763+ for (uint32_t deviceIndex = 0 ; deviceIndex < NumDevices; deviceIndex++) {
764+ uint8_t *DestBinPtr = DestBinPtrs[deviceIndex];
765+ if (!DestBinPtr)
766+ continue ;
767+
767768 auto ZeDevice = Program->AssociatedDevices [deviceIndex]->ZeDevice ;
768769 auto State = Program->getState (ZeDevice);
769770 if (State == ur_program_handle_t_::Native) {
770771 // If Program was created from Native code then return that code.
771- if (PBinary) {
772- std::memcpy (PBinary[deviceIndex], Program->getCode (ZeDevice),
773- Program->getCodeSize (ZeDevice));
774- }
775- SzBinary += Program->getCodeSize (ZeDevice);
776- continue ;
777- }
778- if (State == ur_program_handle_t_::IL ||
779- State == ur_program_handle_t_::Object) {
780- // We don't have a binary for this device, so don't update the output
781- // pointer to the binary, only set return size to 0.
782- if (PropSizeRet)
783- *PropSizeRet = 0 ;
772+ std::memcpy (DestBinPtr, Program->getCode (ZeDevice),
773+ Program->getCodeSize (ZeDevice));
784774 } else if (State == ur_program_handle_t_::Exe) {
785775 auto ZeModule = Program->getZeModuleHandle (ZeDevice);
786776 if (!ZeModule) {
787777 return UR_RESULT_ERROR_INVALID_PROGRAM;
788778 }
789- size_t binarySize = 0 ;
790- if (PBinary) {
791- NativeBinaryPtr = PBinary[deviceIndex];
792- }
793- // If the caller is using a Program which is a built binary, then
794- // the program returned will either be a single module if this is a
795- // native binary or the native binary for each device will be returned.
796- ZE2UR_CALL (zeModuleGetNativeBinary,
797- (ZeModule, &binarySize, NativeBinaryPtr));
798- SzBinary += binarySize;
799- } else {
800- return UR_RESULT_ERROR_INVALID_PROGRAM;
779+ size_t DummySize;
780+ ZE2UR_CALL (zeModuleGetNativeBinary, (ZeModule, &DummySize, DestBinPtr));
801781 }
802782 }
803- if (PropSizeRet)
804- *PropSizeRet = SzBinary;
805783 break ;
806784 }
807785 case UR_PROGRAM_INFO_NUM_KERNELS: {
0 commit comments