|
2 | 2 | // SPDX-License-Identifier: MIT |
3 | 3 |
|
4 | 4 | #include <cstring> |
| 5 | +#include <fstream> |
5 | 6 | #include <uur/environment.h> |
6 | 7 | #include <uur/utils.h> |
7 | 8 |
|
@@ -176,4 +177,147 @@ void DevicesEnvironment::TearDown() { |
176 | 177 | } |
177 | 178 | } |
178 | 179 | } |
| 180 | + |
| 181 | +KernelsEnvironment *KernelsEnvironment::instance = nullptr; |
| 182 | + |
| 183 | +KernelsEnvironment::KernelsEnvironment(int argc, char **argv, std::string kernels_default_dir) : DevicesEnvironment(argc, argv), kernel_options(parseKernelOptions(argc, argv, kernels_default_dir)) { |
| 184 | + instance = this; |
| 185 | + if (!error.empty()) { |
| 186 | + return; |
| 187 | + } |
| 188 | +} |
| 189 | + |
| 190 | +KernelsEnvironment::KernelOptions KernelsEnvironment::parseKernelOptions(int argc, char **argv, std::string kernels_default_dir) { |
| 191 | + KernelOptions options; |
| 192 | + for (int argi = 1; argi < argc; ++argi) { |
| 193 | + const char *arg = argv[argi]; |
| 194 | + if (std::strncmp(arg, "--kernel_directory=", sizeof("--kernel_directory=") - 1) == 0) { |
| 195 | + options.kernel_directory = std::string(&arg[std::strlen("--kernel_directory=")]); |
| 196 | + } |
| 197 | + } |
| 198 | + if (options.kernel_directory.empty()) { |
| 199 | + options.kernel_directory = kernels_default_dir; |
| 200 | + } |
| 201 | + |
| 202 | + return options; |
| 203 | +} |
| 204 | + |
| 205 | +std::string KernelsEnvironment::getSupportedILPostfix(uint32_t device_index) { |
| 206 | + std::stringstream IL; |
| 207 | + |
| 208 | + if (instance->GetDevices().size() == 0) { |
| 209 | + error = "no devices available on the platform"; |
| 210 | + return {}; |
| 211 | + } |
| 212 | + |
| 213 | + auto device = instance->GetDevices()[device_index]; |
| 214 | + size_t size; |
| 215 | + if (urDeviceGetInfo(device, UR_DEVICE_INFO_IL_VERSION, 0, nullptr, &size)) { |
| 216 | + error = "failed getting device IL version"; |
| 217 | + return {}; |
| 218 | + } |
| 219 | + std::string IL_version(size, '\0'); |
| 220 | + if (urDeviceGetInfo(device, UR_DEVICE_INFO_IL_VERSION, size, &IL_version[0], |
| 221 | + nullptr)) { |
| 222 | + error = "failed getting device IL version"; |
| 223 | + return {}; |
| 224 | + } |
| 225 | + |
| 226 | + // Delete the ETX character at the end as it is not part of the name. |
| 227 | + IL_version.pop_back(); |
| 228 | + |
| 229 | + IL << "_" << IL_version; |
| 230 | + |
| 231 | + // TODO: Add other IL types like ptx when they are defined how they will be |
| 232 | + // reported. |
| 233 | + if (IL_version.find("SPIR-V") != std::string::npos) { |
| 234 | + IL << ".spv"; |
| 235 | + } else { |
| 236 | + error = "Undefined IL version: " + IL_version; |
| 237 | + return {}; |
| 238 | + } |
| 239 | + |
| 240 | + return IL.str(); |
| 241 | +} |
| 242 | + |
| 243 | +std::string KernelsEnvironment::getKernelSourcePath(const std::string &kernel_name, uint32_t device_index) { |
| 244 | + std::stringstream path; |
| 245 | + path << instance->getKernelDirectory(); |
| 246 | + // il_postfix = supported_IL(SPIRV-PTX-...) + IL_version + extension(.spv - |
| 247 | + // .ptx - ....) |
| 248 | + std::string il_postfix = getSupportedILPostfix(device_index); |
| 249 | + |
| 250 | + if (il_postfix.empty()) { |
| 251 | + error = "failed getting device supported IL"; |
| 252 | + return {}; |
| 253 | + } |
| 254 | + |
| 255 | + path << "/" << kernel_name << il_postfix; |
| 256 | + |
| 257 | + uint32_t address_bits; |
| 258 | + auto device = instance->GetDevices()[device_index]; |
| 259 | + if (urDeviceGetInfo(device, UR_DEVICE_INFO_ADDRESS_BITS, sizeof(uint32_t), |
| 260 | + &address_bits, nullptr)) { |
| 261 | + error = "failed getting device address bits supported"; |
| 262 | + return {}; |
| 263 | + } |
| 264 | + path << address_bits; |
| 265 | + |
| 266 | + return path.str(); |
| 267 | +} |
| 268 | + |
| 269 | +KernelsEnvironment::KernelSource KernelsEnvironment::LoadSource(const std::string &kernel_name, uint32_t device_index) { |
| 270 | + std::string source_path = |
| 271 | + instance->getKernelSourcePath(kernel_name, device_index); |
| 272 | + |
| 273 | + if (source_path.empty()) { |
| 274 | + error = "failed retrieving kernel source path for kernel: " + kernel_name; |
| 275 | + return KernelSource{&kernel_name[0], nullptr, 0, |
| 276 | + UR_RESULT_ERROR_INVALID_BINARY}; |
| 277 | + } |
| 278 | + |
| 279 | + if (cached_kernels.find(source_path) != cached_kernels.end()) { |
| 280 | + return cached_kernels[source_path]; |
| 281 | + } |
| 282 | + |
| 283 | + std::ifstream source_file; |
| 284 | + source_file.open(source_path, std::ios::binary | std::ios::in | std::ios::ate); |
| 285 | + |
| 286 | + if (!source_file.is_open()) { |
| 287 | + error = "failed opening kernel path: " + source_path; |
| 288 | + return KernelSource{&kernel_name[0], nullptr, 0, |
| 289 | + UR_RESULT_ERROR_INVALID_BINARY}; |
| 290 | + } |
| 291 | + |
| 292 | + uint32_t source_size = static_cast<uint32_t>(source_file.tellg()); |
| 293 | + source_file.seekg(0, std::ios::beg); |
| 294 | + |
| 295 | + char *source = new char[source_size]; |
| 296 | + source_file.read(source, source_size); |
| 297 | + if (!source_file) { |
| 298 | + source_file.close(); |
| 299 | + delete[] source; |
| 300 | + error = "failed reading kernel source data from file: " + source_path; |
| 301 | + return KernelSource{&kernel_name[0], nullptr, 0, |
| 302 | + UR_RESULT_ERROR_INVALID_BINARY}; |
| 303 | + } |
| 304 | + source_file.close(); |
| 305 | + |
| 306 | + KernelSource kernel_source = |
| 307 | + KernelSource{&kernel_name[0], reinterpret_cast<uint32_t *>(source), source_size, UR_RESULT_SUCCESS}; |
| 308 | + |
| 309 | + return cached_kernels[source_path] = kernel_source; |
| 310 | +} |
| 311 | + |
| 312 | +void KernelsEnvironment::SetUp() { |
| 313 | + DevicesEnvironment::SetUp(); |
| 314 | + if (!error.empty()) { |
| 315 | + FAIL() << error; |
| 316 | + } |
| 317 | +} |
| 318 | + |
| 319 | +void KernelsEnvironment::TearDown() { |
| 320 | + cached_kernels.clear(); |
| 321 | + DevicesEnvironment::TearDown(); |
| 322 | +} |
179 | 323 | } // namespace uur |
0 commit comments