@@ -136,7 +136,7 @@ ur_result_t urProgramCreateWithBinary(
136
136
// information to distinguish the cases.
137
137
try {
138
138
for (uint32_t i = 0 ; i < numDevices; i++) {
139
- UR_ASSERT (ppBinaries[i] || !pLengths[0 ], UR_RESULT_ERROR_INVALID_VALUE);
139
+ UR_ASSERT (ppBinaries[i] || !pLengths[i ], UR_RESULT_ERROR_INVALID_VALUE);
140
140
UR_ASSERT (hContext->isValidDevice (phDevices[i]),
141
141
UR_RESULT_ERROR_INVALID_DEVICE);
142
142
}
@@ -746,62 +746,40 @@ ur_result_t urProgramGetInfo(
746
746
return ReturnValue (binarySizes.data (), binarySizes.size ());
747
747
}
748
748
case UR_PROGRAM_INFO_BINARIES: {
749
- // The caller sets "ParamValue" to an array of pointers, one for each
750
- // device.
751
- uint8_t **PBinary = nullptr ;
752
- if (ProgramInfo) {
753
- PBinary = ur_cast<uint8_t **>(ProgramInfo);
754
- if (!PBinary[0 ]) {
755
- break ;
756
- }
757
- }
758
749
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
759
- uint8_t *NativeBinaryPtr = nullptr ;
760
- if (PBinary) {
761
- NativeBinaryPtr = PBinary[0 ];
750
+ size_t NumDevices = Program->AssociatedDevices .size ();
751
+ if (PropSizeRet) {
752
+ // Return the size of the array of pointers to binaries (for each device).
753
+ *PropSizeRet = NumDevices * sizeof (uint8_t *);
762
754
}
763
755
764
- size_t SzBinary = 0 ;
765
- for (uint32_t deviceIndex = 0 ;
766
- deviceIndex < Program->AssociatedDevices .size (); deviceIndex++) {
756
+ // If the caller did not provide an array of pointers to copy binaries into,
757
+ // return early.
758
+ if (!ProgramInfo)
759
+ break ;
760
+
761
+ // If the caller provided an array of pointers, copy the binaries.
762
+ uint8_t **DestBinPtrs = ur_cast<uint8_t **>(ProgramInfo);
763
+ for (uint32_t deviceIndex = 0 ; deviceIndex < NumDevices; deviceIndex++) {
764
+ uint8_t *DestBinPtr = DestBinPtrs[deviceIndex];
765
+ if (!DestBinPtr)
766
+ continue ;
767
+
767
768
auto ZeDevice = Program->AssociatedDevices [deviceIndex]->ZeDevice ;
768
769
auto State = Program->getState (ZeDevice);
769
770
if (State == ur_program_handle_t_::Native) {
770
771
// If Program was created from Native code then return that code.
771
- if (PBinary) {
772
- std::memcpy (PBinary[deviceIndex], Program->getCode (ZeDevice),
773
- Program->getCodeSize (ZeDevice));
774
- }
775
- SzBinary += Program->getCodeSize (ZeDevice);
776
- continue ;
777
- }
778
- if (State == ur_program_handle_t_::IL ||
779
- State == ur_program_handle_t_::Object) {
780
- // We don't have a binary for this device, so don't update the output
781
- // pointer to the binary, only set return size to 0.
782
- if (PropSizeRet)
783
- *PropSizeRet = 0 ;
772
+ std::memcpy (DestBinPtr, Program->getCode (ZeDevice),
773
+ Program->getCodeSize (ZeDevice));
784
774
} else if (State == ur_program_handle_t_::Exe) {
785
775
auto ZeModule = Program->getZeModuleHandle (ZeDevice);
786
776
if (!ZeModule) {
787
777
return UR_RESULT_ERROR_INVALID_PROGRAM;
788
778
}
789
- size_t binarySize = 0 ;
790
- if (PBinary) {
791
- NativeBinaryPtr = PBinary[deviceIndex];
792
- }
793
- // If the caller is using a Program which is a built binary, then
794
- // the program returned will either be a single module if this is a
795
- // native binary or the native binary for each device will be returned.
796
- ZE2UR_CALL (zeModuleGetNativeBinary,
797
- (ZeModule, &binarySize, NativeBinaryPtr));
798
- SzBinary += binarySize;
799
- } else {
800
- return UR_RESULT_ERROR_INVALID_PROGRAM;
779
+ size_t DummySize;
780
+ ZE2UR_CALL (zeModuleGetNativeBinary, (ZeModule, &DummySize, DestBinPtr));
801
781
}
802
782
}
803
- if (PropSizeRet)
804
- *PropSizeRet = SzBinary;
805
783
break ;
806
784
}
807
785
case UR_PROGRAM_INFO_NUM_KERNELS: {
0 commit comments