1414#include " ../sampler.hpp"
1515#include " ../ur_interface_loader.hpp"
1616#include " command_buffer.hpp"
17+ #include " common.hpp"
1718#include " context.hpp"
1819#include " kernel.hpp"
1920#include " memory.hpp"
@@ -149,21 +150,13 @@ ur_command_list_manager::getSignalEvent(ur_event_handle_t hUserEvent,
149150 }
150151}
151152
152- ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked (
153- ur_kernel_handle_t hKernel, uint32_t workDim,
153+ // must be called with hKernel->Mutex held
154+ ur_result_t ur_command_list_manager::appendKernelLaunchLocked (
155+ ur_kernel_handle_t hKernel, ze_kernel_handle_t hZeKernel, uint32_t workDim,
154156 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
155157 const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
156158 const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent,
157- bool cooperative) {
158- UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
159- UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
160-
161- UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
162- UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
163-
164- ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice.get ());
165-
166- std::scoped_lock<ur_shared_mutex> Lock (hKernel->Mutex );
159+ bool cooperative, std::vector<void *> &kMemObj , void *pNext) {
167160
168161 ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
169162 uint32_t WG[3 ]{};
@@ -176,15 +169,28 @@ ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked(
176169
177170 UR_CALL (hKernel->prepareForSubmission (
178171 hContext.get (), hDevice.get (), pGlobalWorkOffset, workDim, WG[0 ], WG[1 ],
179- WG[2 ], getZeCommandList (), waitListView));
172+ WG[2 ], getZeCommandList (), waitListView, kMemObj ));
180173
181- if (cooperative) {
174+ if (!kMemObj .empty ()) {
175+ // zeCommandListAppendLaunchKernelWithArguments
176+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::"
177+ " zeCommandListAppendLaunchKernelWithArguments" );
178+ ze_group_size_t groupSize = {WG[0 ], WG[1 ], WG[2 ]};
179+ ZE2UR_CALL (hContext->getPlatform ()
180+ ->ZeCommandListAppendLaunchKernelWithArgumentsExt
181+ .zeCommandListAppendLaunchKernelWithArguments ,
182+ (getZeCommandList (), hZeKernel, zeThreadGroupDimensions,
183+ groupSize, hKernel->kernelArgs .data (), pNext, zeSignalEvent,
184+ waitListView.num , waitListView.handles ));
185+ } else if (cooperative) {
186+ // zeCommandListAppendLaunchCooperativeKernel
182187 TRACK_SCOPE_LATENCY (" ur_command_list_manager::"
183188 " zeCommandListAppendLaunchCooperativeKernel" );
184189 ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
185190 (getZeCommandList (), hZeKernel, &zeThreadGroupDimensions,
186191 zeSignalEvent, waitListView.num , waitListView.handles ));
187192 } else {
193+ // zeCommandListAppendLaunchKernel
188194 TRACK_SCOPE_LATENCY (" ur_command_list_manager::"
189195 " zeCommandListAppendLaunchKernel" );
190196 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
@@ -199,6 +205,39 @@ ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked(
199205 return UR_RESULT_SUCCESS;
200206}
201207
208+ static ur_result_t kernelLaunchChecks (ur_kernel_handle_t hKernel,
209+ uint32_t workDim) {
210+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
211+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
212+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
213+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
214+
215+ return UR_RESULT_SUCCESS;
216+ }
217+
218+ ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked (
219+ ur_kernel_handle_t hKernel, uint32_t workDim,
220+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
221+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
222+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent,
223+ bool cooperative) {
224+
225+ ur_result_t checkResult = kernelLaunchChecks (hKernel, workDim);
226+ if (checkResult != UR_RESULT_SUCCESS) {
227+ return checkResult;
228+ }
229+
230+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice.get ());
231+ std::vector<void *> emptyKMemObj;
232+
233+ std::scoped_lock<ur_shared_mutex> Lock (hKernel->Mutex );
234+
235+ return appendKernelLaunchLocked (
236+ hKernel, hZeKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
237+ pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent,
238+ cooperative, emptyKMemObj, nullptr /* pNext */ );
239+ }
240+
202241ur_result_t ur_command_list_manager::appendKernelLaunch (
203242 ur_kernel_handle_t hKernel, uint32_t workDim,
204243 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -1039,7 +1078,7 @@ ur_result_t ur_command_list_manager::releaseSubmittedKernels() {
10391078 return UR_RESULT_SUCCESS;
10401079}
10411080
1042- ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp (
1081+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld (
10431082 ur_kernel_handle_t hKernel, uint32_t workDim,
10441083 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
10451084 const size_t *pLocalWorkSize, uint32_t numArgs,
@@ -1048,8 +1087,6 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10481087 const ur_kernel_launch_property_t *launchPropList,
10491088 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
10501089 ur_event_handle_t phEvent) {
1051- TRACK_SCOPE_LATENCY (
1052- " ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp" );
10531090 {
10541091 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
10551092 for (uint32_t argIndex = 0 ; argIndex < numArgs; argIndex++) {
@@ -1091,7 +1128,129 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10911128 numPropsInLaunchPropList, launchPropList,
10921129 numEventsInWaitList, phEventWaitList, phEvent));
10931130
1094- recordSubmittedKernel (hKernel);
1131+ return UR_RESULT_SUCCESS;
1132+ }
1133+
1134+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew (
1135+ ur_kernel_handle_t hKernel, uint32_t workDim,
1136+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1137+ const size_t *pLocalWorkSize, uint32_t numArgs,
1138+ const ur_exp_kernel_arg_properties_t *pArgs, uint32_t numEventsInWaitList,
1139+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent,
1140+ bool cooperativeKernelLaunchRequested) {
1141+
1142+ ur_result_t checkResult = kernelLaunchChecks (hKernel, workDim);
1143+ if (checkResult != UR_RESULT_SUCCESS) {
1144+ return checkResult;
1145+ }
1146+
1147+ // It is needed in case of UR_KERNEL_LAUNCH_PROPERTY_ID_COOPERATIVE
1148+ // to launch the cooperative kernel.
1149+ ZeStruct<ze_command_list_append_launch_kernel_param_cooperative_desc_t >
1150+ cooperativeDesc;
1151+ cooperativeDesc.isCooperative = static_cast <ze_bool_t >(true );
1152+
1153+ void *pNext = nullptr ;
1154+ if (cooperativeKernelLaunchRequested) {
1155+ pNext = &cooperativeDesc;
1156+ }
1157+
1158+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice.get ());
1159+
1160+ std::scoped_lock<ur_shared_mutex> Lock (hKernel->Mutex );
1161+
1162+ // kernelMemObj contains kernel memory objects that
1163+ // UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ kernelArgs pointers point to
1164+ hKernel->kernelMemObj .resize (numArgs, 0 );
1165+ hKernel->kernelArgs .resize (numArgs, 0 );
1166+
1167+ for (uint32_t argIndex = 0 ; argIndex < numArgs; argIndex++) {
1168+ switch (pArgs[argIndex].type ) {
1169+ case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
1170+ hKernel->kernelArgs [argIndex] = (void *)&pArgs[argIndex].size ;
1171+ break ;
1172+ case UR_EXP_KERNEL_ARG_TYPE_VALUE:
1173+ hKernel->kernelArgs [argIndex] = (void *)pArgs[argIndex].value .value ;
1174+ break ;
1175+ case UR_EXP_KERNEL_ARG_TYPE_POINTER:
1176+ hKernel->kernelArgs [argIndex] = (void *)&pArgs[argIndex].value .pointer ;
1177+ break ;
1178+ case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
1179+ // prepareForSubmission() will save zePtr in kernelMemObj[argIndex]
1180+ hKernel->kernelArgs [argIndex] = &hKernel->kernelMemObj [argIndex];
1181+ UR_CALL (hKernel->addPendingMemoryAllocation (
1182+ {pArgs[argIndex].value .memObjTuple .hMem ,
1183+ ur_mem_buffer_t ::device_access_mode_t ::read_write,
1184+ pArgs[argIndex].index }));
1185+ break ;
1186+ case UR_EXP_KERNEL_ARG_TYPE_SAMPLER:
1187+ hKernel->kernelArgs [argIndex] = &pArgs[argIndex].value .sampler ->ZeSampler ;
1188+ break ;
1189+ default :
1190+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
1191+ }
1192+ }
1193+
1194+ return appendKernelLaunchLocked (
1195+ hKernel, hZeKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1196+ pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent,
1197+ cooperativeKernelLaunchRequested, hKernel->kernelMemObj , pNext);
1198+ }
1199+
1200+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp (
1201+ ur_kernel_handle_t hKernel, uint32_t workDim,
1202+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1203+ const size_t *pLocalWorkSize, uint32_t numArgs,
1204+ const ur_exp_kernel_arg_properties_t *pArgs,
1205+ uint32_t numPropsInLaunchPropList,
1206+ const ur_kernel_launch_property_t *launchPropList,
1207+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1208+ ur_event_handle_t phEvent) {
1209+ TRACK_SCOPE_LATENCY (
1210+ " ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp" );
1211+
1212+ bool cooperativeKernelLaunchRequested = false ;
1213+
1214+ for (uint32_t propIndex = 0 ; propIndex < numPropsInLaunchPropList;
1215+ propIndex++) {
1216+ switch (launchPropList[propIndex].id ) {
1217+ case UR_KERNEL_LAUNCH_PROPERTY_ID_IGNORE:
1218+ break ;
1219+ case UR_KERNEL_LAUNCH_PROPERTY_ID_COOPERATIVE:
1220+ if (launchPropList[propIndex].value .cooperative ) {
1221+ cooperativeKernelLaunchRequested = true ;
1222+ }
1223+ break ;
1224+ default :
1225+ // We don't support any other properties.
1226+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1227+ }
1228+ }
1229+
1230+ ur_platform_handle_t hPlatform = hContext->getPlatform ();
1231+ bool KernelWithArgsSupported =
1232+ hPlatform->ZeCommandListAppendLaunchKernelWithArgumentsExt .Supported ;
1233+ bool CooperativeCompatible =
1234+ hPlatform->ZeCommandListAppendLaunchKernelWithArgumentsExt
1235+ .DriverSupportsCooperativeKernelLaunchWithArgs ;
1236+ bool RunNewPath =
1237+ KernelWithArgsSupported &&
1238+ (!cooperativeKernelLaunchRequested ||
1239+ (cooperativeKernelLaunchRequested && CooperativeCompatible));
1240+ if (RunNewPath) {
1241+ return appendKernelLaunchWithArgsExpNew (
1242+ hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1243+ numArgs, pArgs, numEventsInWaitList, phEventWaitList, phEvent,
1244+ cooperativeKernelLaunchRequested);
1245+ } else {
1246+ // We cannot pass cooperativeKernelLaunchRequested to
1247+ // appendKernelLaunchWithArgsExpOld() because appendKernelLaunch() must
1248+ // check it on its own since it is called also from enqueueKernelLaunch().
1249+ return appendKernelLaunchWithArgsExpOld (
1250+ hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1251+ numArgs, pArgs, numPropsInLaunchPropList, launchPropList,
1252+ numEventsInWaitList, phEventWaitList, phEvent);
1253+ }
10951254
10961255 return UR_RESULT_SUCCESS;
10971256}
0 commit comments