@@ -46,18 +46,36 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
4646 ur_device_handle_t Device,
4747 const ur_usm_desc_t *Properties,
4848 ur_usm_pool_handle_t Pool,
49- size_t Size, void **ResultPtr) {
49+ size_t Size, AllocType Type,
50+ void **ResultPtr) {
5051
5152 auto ContextInfo = getContextInfo (Context);
52- std::shared_ptr<DeviceInfo> DeviceInfo = getDeviceInfo (Device);
53+ std::shared_ptr<DeviceInfo> DeviceInfo =
54+ Device ? getDeviceInfo (Device) : nullptr ;
5355
5456 void *Allocated = nullptr ;
5557
56- UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
57- Context, Device, Properties, Pool, Size, &Allocated));
58+ if (Type == AllocType::DEVICE_USM) {
59+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
60+ Context, Device, Properties, Pool, Size, &Allocated));
61+ } else if (Type == AllocType::HOST_USM) {
62+ UR_CALL (getContext ()->urDdiTable .USM .pfnHostAlloc (
63+ Context, Properties, Pool, Size, &Allocated));
64+ } else if (Type == AllocType::SHARED_USM) {
65+ UR_CALL (getContext ()->urDdiTable .USM .pfnSharedAlloc (
66+ Context, Device, Properties, Pool, Size, &Allocated));
67+ }
5868
5969 *ResultPtr = Allocated;
6070
71+ ContextInfo->MaxAllocatedSize =
72+ std::max (ContextInfo->MaxAllocatedSize , Size);
73+
74+ // For host/shared usm, we only record the alloc size.
75+ if (Type != AllocType::DEVICE_USM) {
76+ return UR_RESULT_SUCCESS;
77+ }
78+
6179 auto AI =
6280 std::make_shared<MsanAllocInfo>(MsanAllocInfo{(uptr)Allocated,
6381 Size,
@@ -145,6 +163,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
145163 return Result;
146164 }
147165
166+ getContext ()->logger .info (" registerDeviceGlobals" );
167+ Result = registerDeviceGlobals (Program);
168+ if (Result != UR_RESULT_SUCCESS) {
169+ return Result;
170+ }
171+
148172 return Result;
149173}
150174
@@ -213,6 +237,56 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
213237 return UR_RESULT_SUCCESS;
214238}
215239
240+ ur_result_t
241+ MsanInterceptor::registerDeviceGlobals (ur_program_handle_t Program) {
242+ std::vector<ur_device_handle_t > Devices = GetDevices (Program);
243+ assert (Devices.size () != 0 && " No devices in registerDeviceGlobals" );
244+ auto Context = GetContext (Program);
245+ auto ContextInfo = getContextInfo (Context);
246+ auto ProgramInfo = getProgramInfo (Program);
247+ assert (ProgramInfo != nullptr && " unregistered program!" );
248+
249+ for (auto Device : Devices) {
250+ ManagedQueue Queue (Context, Device);
251+
252+ size_t MetadataSize;
253+ void *MetadataPtr;
254+ auto Result =
255+ getContext ()->urDdiTable .Program .pfnGetGlobalVariablePointer (
256+ Device, Program, kSPIR_MsanDeviceGlobalMetadata , &MetadataSize,
257+ &MetadataPtr);
258+ if (Result != UR_RESULT_SUCCESS) {
259+ getContext ()->logger .info (" No device globals" );
260+ continue ;
261+ }
262+
263+ const uint64_t NumOfDeviceGlobal =
264+ MetadataSize / sizeof (DeviceGlobalInfo);
265+ assert ((MetadataSize % sizeof (DeviceGlobalInfo) == 0 ) &&
266+ " DeviceGlobal metadata size is not correct" );
267+ std::vector<DeviceGlobalInfo> GVInfos (NumOfDeviceGlobal);
268+ Result = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
269+ Queue, true , &GVInfos[0 ], MetadataPtr,
270+ sizeof (DeviceGlobalInfo) * NumOfDeviceGlobal, 0 , nullptr , nullptr );
271+ if (Result != UR_RESULT_SUCCESS) {
272+ getContext ()->logger .error (" Device Global[{}] Read Failed: {}" ,
273+ kSPIR_MsanDeviceGlobalMetadata , Result);
274+ return Result;
275+ }
276+
277+ auto DeviceInfo = getMsanInterceptor ()->getDeviceInfo (Device);
278+ for (size_t i = 0 ; i < NumOfDeviceGlobal; i++) {
279+ const auto &GVInfo = GVInfos[i];
280+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (Queue, GVInfo.Addr ,
281+ GVInfo.Size , 0 ));
282+ ContextInfo->MaxAllocatedSize =
283+ std::max (ContextInfo->MaxAllocatedSize , GVInfo.Size );
284+ }
285+ }
286+
287+ return UR_RESULT_SUCCESS;
288+ }
289+
216290ur_result_t MsanInterceptor::insertContext (ur_context_handle_t Context,
217291 std::shared_ptr<ContextInfo> &CI) {
218292 std::scoped_lock<ur_shared_mutex> Guard (m_ContextMapMutex);
@@ -380,10 +454,14 @@ ur_result_t MsanInterceptor::prepareLaunch(
380454 }
381455
382456 // Set LaunchInfo
457+ auto ContextInfo = getContextInfo (LaunchInfo.Context );
383458 LaunchInfo.Data ->GlobalShadowOffset = DeviceInfo->Shadow ->ShadowBegin ;
384459 LaunchInfo.Data ->GlobalShadowOffsetEnd = DeviceInfo->Shadow ->ShadowEnd ;
385460 LaunchInfo.Data ->DeviceTy = DeviceInfo->Type ;
386461 LaunchInfo.Data ->Debug = getOptions ().Debug ? 1 : 0 ;
462+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
463+ ContextInfo->Handle , DeviceInfo->Handle , nullptr , nullptr ,
464+ ContextInfo->MaxAllocatedSize , &LaunchInfo.Data ->CleanShadow ));
387465
388466 getContext ()->logger .info (
389467 " launch_info {} (GlobalShadow={}, Device={}, Debug={})" ,
@@ -466,6 +544,11 @@ ur_result_t USMLaunchInfo::initialize() {
466544USMLaunchInfo::~USMLaunchInfo () {
467545 [[maybe_unused]] ur_result_t Result;
468546 if (Data) {
547+ if (Data->CleanShadow ) {
548+ Result = getContext ()->urDdiTable .USM .pfnFree (Context,
549+ Data->CleanShadow );
550+ assert (Result == UR_RESULT_SUCCESS);
551+ }
469552 Result = getContext ()->urDdiTable .USM .pfnFree (Context, (void *)Data);
470553 assert (Result == UR_RESULT_SUCCESS);
471554 }
0 commit comments