Skip to content

Commit fd89ae4

Browse files
[NFCI][SYCL] Change device_image_impl::MDevices to store raw device_impl * (#19459)
#18251 extended the lifetimes of `device_impl` objects until shutdown, and #18270 started to pass devices as raw pointers in some of the APIs. This PR builds on top of that and extends the usage of raw pointers/references/`devices_range`: the devices are known to be alive, so the extra atomic reference-count increments of `std::shared_ptr` aren't necessary and can be avoided. Since we change the type of `device_image_impl::MDevices`, other APIs in that class and in `program_manager` no longer need to operate in terms of `sycl::device` or `std::shared_ptr<device_impl>`, and we can switch them to use `devices_range` instead. A small number of other modifications are caused by these API changes and are necessary to keep the code buildable. One extra change is the addition of a minor `devices_range::to<std::vector<ur_device_handle_t>>()` helper that we can use now that most of the arguments are `devices_range`. Technically, this helper could go in a separate PR, but then we'd just be modifying the exact same lines twice, so I decided to fuse it here.
1 parent efe5a5b commit fd89ae4

File tree

9 files changed

+190
-191
lines changed

9 files changed

+190
-191
lines changed

sycl/source/detail/context_impl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,13 @@ void context_impl::removeAssociatedDeviceGlobal(const void *DeviceGlobalPtr) {
326326
}
327327

328328
void context_impl::addDeviceGlobalInitializer(
329-
ur_program_handle_t Program, const std::vector<device> &Devs,
329+
ur_program_handle_t Program, devices_range Devs,
330330
const RTDeviceBinaryImage *BinImage) {
331331
if (BinImage->getDeviceGlobals().empty())
332332
return;
333333
std::lock_guard<std::mutex> Lock(MDeviceGlobalInitializersMutex);
334-
for (const device &Dev : Devs) {
335-
auto Key = std::make_pair(Program, getSyclObjImpl(Dev)->getHandleRef());
334+
for (device_impl &Dev : Devs) {
335+
auto Key = std::make_pair(Program, Dev.getHandleRef());
336336
auto [Iter, Inserted] = MDeviceGlobalInitializers.emplace(Key, BinImage);
337337
if (Inserted && !Iter->second.MDeviceGlobalsFullyInitialized)
338338
++MDeviceGlobalNotInitializedCnt;

sycl/source/detail/context_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
216216

217217
/// Adds a device global initializer.
218218
void addDeviceGlobalInitializer(ur_program_handle_t Program,
219-
const std::vector<device> &Devs,
219+
devices_range Devs,
220220
const RTDeviceBinaryImage *BinImage);
221221

222222
/// Initializes device globals for a program on the associated queue.

sycl/source/detail/device_image_impl.hpp

Lines changed: 59 additions & 65 deletions
Large diffs are not rendered by default.

sycl/source/detail/device_impl.hpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,16 +2297,31 @@ struct devices_deref_impl {
22972297
using devices_iterator =
22982298
variadic_iterator<devices_deref_impl, device,
22992299
std::vector<std::shared_ptr<device_impl>>::const_iterator,
2300-
std::vector<device>::const_iterator, device_impl *>;
2300+
std::vector<device>::const_iterator,
2301+
std::vector<device_impl *>::const_iterator,
2302+
device_impl *>;
23012303

23022304
class devices_range : public iterator_range<devices_iterator> {
23032305
private:
23042306
using Base = iterator_range<devices_iterator>;
23052307

23062308
public:
23072309
using Base::Base;
2308-
devices_range(const device &Dev)
2309-
: devices_range(&*getSyclObjImpl(Dev), (&*getSyclObjImpl(Dev) + 1), 1) {}
2310+
template <typename Container>
2311+
decltype(std::declval<Base>().to<Container>()) to() const {
2312+
return this->Base::to<Container>();
2313+
}
2314+
2315+
template <typename Container>
2316+
std::enable_if_t<std::is_same_v<Container, std::vector<ur_device_handle_t>>,
2317+
Container>
2318+
to() const {
2319+
std::vector<ur_device_handle_t> DeviceHandles;
2320+
DeviceHandles.reserve(size());
2321+
std::transform(begin(), end(), std::back_inserter(DeviceHandles),
2322+
[](device_impl &Dev) { return Dev.getHandleRef(); });
2323+
return DeviceHandles;
2324+
}
23102325
};
23112326

23122327
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES

sycl/source/detail/helpers.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ template <typename iterator> class iterator_range {
110110
iterator_range(const ContainerTy &Container)
111111
: iterator_range(Container.begin(), Container.end(), Container.size()) {}
112112

113+
iterator_range(value_type &Obj) : iterator_range(&Obj, &Obj + 1, 1) {}
114+
115+
iterator_range(const sycl_type &Obj)
116+
: iterator_range(&*getSyclObjImpl(Obj), (&*getSyclObjImpl(Obj) + 1), 1) {}
117+
113118
iterator begin() const { return Begin; }
114119
iterator end() const { return End; }
115120
size_t size() const { return Size; }

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,29 @@ class kernel_bundle_impl
11381138
DeviceGlobalMap MDeviceGlobals{/*OwnerControlledCleanup=*/false};
11391139
};
11401140

1141+
/// Checks whether every kernel in \p KernelIDs can run on device \p Dev.
///
/// For each kernel ID, the raw device binary images obtained from
/// ProgramManager::getRawDeviceImages are inspected; the kernel is accepted
/// if at least one of them passes both doesDevSupportDeviceRequirements and
/// doesImageTargetMatchDevice for \p Dev.
///
/// \param KernelIDs kernel identifiers to check; an empty list is treated as
///        compatible.
/// \param Dev device implementation the kernels are checked against.
/// \return true if KernelIDs is empty or every kernel has at least one
///         matching image; false as soon as one kernel has none.
inline bool is_compatible(const std::vector<kernel_id> &KernelIDs,
1142+
device_impl &Dev) {
1143+
if (KernelIDs.empty())
1144+
return true;
1145+
// One kernel may be contained in several binary images depending on the
1146+
// number of targets. This kernel is compatible with the device if there is
1147+
// at least one image (containing this kernel) whose aspects are supported by
1148+
// the device and whose target matches the device.
1149+
for (const auto &KernelID : KernelIDs) {
1150+
std::set<const detail::RTDeviceBinaryImage *> BinImages =
1151+
detail::ProgramManager::getInstance().getRawDeviceImages({KernelID});
1152+
1153+
if (std::none_of(BinImages.begin(), BinImages.end(),
1154+
[&](const detail::RTDeviceBinaryImage *Img) {
1155+
return doesDevSupportDeviceRequirements(Dev, *Img) &&
1156+
doesImageTargetMatchDevice(*Img, Dev);
1157+
}))
1158+
return false;
1159+
}
1160+
1161+
return true;
1162+
}
1163+
11411164
} // namespace detail
11421165
} // namespace _V1
11431166
} // namespace sycl

0 commit comments

Comments
 (0)