@@ -299,78 +299,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
299
299
return Plugin::error (ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str ());
300
300
};
301
301
302
- // Find the info if it exists under any of the given names
303
- auto getInfoString =
304
- [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
305
- for (auto &Name : Names) {
306
- if (auto Entry = Device->Info .get (Name)) {
307
- if (!std::holds_alternative<std::string>((*Entry)->Value ))
308
- return makeError (ErrorCode::BACKEND_FAILURE,
309
- " plugin returned incorrect type" );
310
- return std::get<std::string>((*Entry)->Value ).c_str ();
311
- }
312
- }
313
-
314
- return makeError (ErrorCode::UNIMPLEMENTED,
315
- " plugin did not provide a response for this information" );
316
- };
317
-
318
- auto getInfoXyz =
319
- [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t > {
320
- for (auto &Name : Names) {
321
- if (auto Entry = Device->Info .get (Name)) {
322
- auto Node = *Entry;
323
- ol_dimensions_t Out{0 , 0 , 0 };
324
-
325
- auto getField = [&](StringRef Name, uint32_t &Dest) {
326
- if (auto F = Node->get (Name)) {
327
- if (!std::holds_alternative<size_t >((*F)->Value ))
328
- return makeError (
329
- ErrorCode::BACKEND_FAILURE,
330
- " plugin returned incorrect type for dimensions element" );
331
- Dest = std::get<size_t >((*F)->Value );
332
- } else
333
- return makeError (ErrorCode::BACKEND_FAILURE,
334
- " plugin didn't provide all values for dimensions" );
335
- return Plugin::success ();
336
- };
337
-
338
- if (auto Res = getField (" x" , Out.x ))
339
- return Res;
340
- if (auto Res = getField (" y" , Out.y ))
341
- return Res;
342
- if (auto Res = getField (" z" , Out.z ))
343
- return Res;
344
-
345
- return Out;
346
- }
347
- }
302
+ // These are not implemented by the plugin interface
303
+ if (PropName == OL_DEVICE_INFO_PLATFORM)
304
+ return Info.write <void *>(Device->Platform );
305
+ if (PropName == OL_DEVICE_INFO_TYPE)
306
+ return Info.write <ol_device_type_t >(OL_DEVICE_TYPE_GPU);
307
+ // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
308
+ if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
309
+ return createOffloadError (ErrorCode::INVALID_ENUMERATION,
310
+ " getDeviceInfo enum '%i' is invalid" , PropName);
348
311
312
+ auto EntryOpt = Device->Info .get (static_cast <DeviceInfo>(PropName));
313
+ if (!EntryOpt)
349
314
return makeError (ErrorCode::UNIMPLEMENTED,
350
315
" plugin did not provide a response for this information" );
351
- } ;
316
+ auto Entry = *EntryOpt ;
352
317
353
318
switch (PropName) {
354
- case OL_DEVICE_INFO_PLATFORM:
355
- return Info.write <void *>(Device->Platform );
356
- case OL_DEVICE_INFO_TYPE:
357
- return Info.write <ol_device_type_t >(OL_DEVICE_TYPE_GPU);
358
319
case OL_DEVICE_INFO_NAME:
359
- return Info.writeString (getInfoString ({" Device Name" }));
360
320
case OL_DEVICE_INFO_VENDOR:
361
- return Info.writeString (getInfoString ({" Vendor Name" }));
362
- case OL_DEVICE_INFO_DRIVER_VERSION:
363
- return Info.writeString (
364
- getInfoString ({" CUDA Driver Version" , " HSA Runtime Version" }));
365
- case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
366
- return Info.write (getInfoXyz ({" Workgroup Max Size per Dimension" /* AMD*/ ,
367
- " Maximum Block Dimensions" /* CUDA*/ }));
368
- default :
369
- return createOffloadError (ErrorCode::INVALID_ENUMERATION,
370
- " getDeviceInfo enum '%i' is invalid" , PropName);
321
+ case OL_DEVICE_INFO_DRIVER_VERSION: {
322
+ // String values
323
+ if (!std::holds_alternative<std::string>(Entry->Value ))
324
+ return makeError (ErrorCode::BACKEND_FAILURE,
325
+ " plugin returned incorrect type" );
326
+ return Info.writeString (std::get<std::string>(Entry->Value ).c_str ());
371
327
}
372
328
373
- return Error::success ();
329
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
330
+ // {x, y, z} triples
331
+ ol_dimensions_t Out{0 , 0 , 0 };
332
+
333
+ auto getField = [&](StringRef Name, uint32_t &Dest) {
334
+ if (auto F = Entry->get (Name)) {
335
+ if (!std::holds_alternative<size_t >((*F)->Value ))
336
+ return makeError (
337
+ ErrorCode::BACKEND_FAILURE,
338
+ " plugin returned incorrect type for dimensions element" );
339
+ Dest = std::get<size_t >((*F)->Value );
340
+ } else
341
+ return makeError (ErrorCode::BACKEND_FAILURE,
342
+ " plugin didn't provide all values for dimensions" );
343
+ return Plugin::success ();
344
+ };
345
+
346
+ if (auto Res = getField (" x" , Out.x ))
347
+ return Res;
348
+ if (auto Res = getField (" y" , Out.y ))
349
+ return Res;
350
+ if (auto Res = getField (" z" , Out.z ))
351
+ return Res;
352
+
353
+ return Info.write (Out);
354
+ }
355
+
356
+ default :
357
+ llvm_unreachable (" Unimplemented device info" );
358
+ }
374
359
}
375
360
376
361
Error olGetDeviceInfoImplDetailHost (ol_device_handle_t Device,
0 commit comments