diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index 3738f8285af..b0ecbe65fc6 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -49,7 +49,7 @@ Kernel* registered_kernels = reinterpret_cast(registered_kernels_data); size_t num_registered_kernels = 0; // Registers the kernels, but may return an error. -Error register_kernels_internal(const Span kernels) { +Error register_kernels_internal(const Span kernels, ErrorHandler errorHandler) { // Operator registration happens in static initialization time before or after // PAL init, so call it here. It is safe to call multiple times. ::et_pal_init(); @@ -74,12 +74,19 @@ Error register_kernels_internal(const Span kernels) { ET_LOG(Error, "%s", kernels[i].name_); ET_LOG_KERNEL_KEY(kernels[i].kernel_key_); } + + if (errorHandler != nullptr) { + return errorHandler(Error::RegistrationExceedingMaxKernels); + } + return Error::RegistrationExceedingMaxKernels; } // for debugging purpose ET_UNUSED const char* lib_name = et_pal_get_shared_library_name(kernels.data()); + Error err = Error::Ok; + for (const auto& kernel : kernels) { // Linear search. This is fine if the number of kernels is small. for (size_t i = 0; i < num_registered_kernels; i++) { @@ -88,24 +95,33 @@ Error register_kernels_internal(const Span kernels) { kernel.kernel_key_ == k.kernel_key_) { ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name); ET_LOG_KERNEL_KEY(k.kernel_key_); - return Error::RegistrationAlreadyRegistered; + err = Error::RegistrationAlreadyRegistered; + continue; } } + registered_kernels[num_registered_kernels++] = kernel; } - ET_LOG( - Debug, - "Successfully registered all kernels from shared library: %s", - lib_name); - return Error::Ok; + if (errorHandler != nullptr) { + err = errorHandler(err); + } + + if (err == Error::Ok) { + ET_LOG( + Debug, + "Successfully registered all kernels from shared library: %s", + lib_name); + } + + return err; } } // namespace // Registers the kernels, but panics if an error occurs. Always returns Ok. -Error register_kernels(const Span kernels) { - Error success = register_kernels_internal(kernels); +Error register_kernels(const Span kernels, ErrorHandler errorHandler) { + Error success = register_kernels_internal(kernels, errorHandler); if (success == Error::RegistrationAlreadyRegistered || success == Error::RegistrationExceedingMaxKernels) { ET_CHECK_MSG( diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index 9bd6318676c..f8b0b036a53 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -237,6 +237,8 @@ ::executorch::runtime::Result get_op_function_from_registry( */ Span get_registered_kernels(); +using ErrorHandler = Error (*)(Error errorCode); + /** * Registers the provided kernels. * @@ -244,7 +246,7 @@ Span get_registered_kernels(); * @retval Error::Ok always. Panics on error. This function needs to return a * non-void type to run at static initialization time. */ -ET_NODISCARD Error register_kernels(const Span); +ET_NODISCARD Error register_kernels(const Span, ErrorHandler errorHandler = nullptr); /** * Registers a single kernel. @@ -253,8 +255,8 @@ ET_NODISCARD Error register_kernels(const Span); * @retval Error::Ok always. Panics on error. This function needs to return a * non-void type to run at static initialization time. */ -ET_NODISCARD inline Error register_kernel(const Kernel& kernel) { - return register_kernels({&kernel, 1}); +ET_NODISCARD inline Error register_kernel(const Kernel& kernel, ErrorHandler errorHandler = nullptr) { + return register_kernels({&kernel, 1}, errorHandler); }; } // namespace ET_RUNTIME_NAMESPACE @@ -269,12 +271,13 @@ using ::executorch::ET_RUNTIME_NAMESPACE::KernelKey; using ::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext; using ::executorch::ET_RUNTIME_NAMESPACE::OpFunction; using ::executorch::ET_RUNTIME_NAMESPACE::TensorMeta; +using ::executorch::ET_RUNTIME_NAMESPACE::ErrorHandler; using KernelRuntimeContext = ::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext; -inline ::executorch::runtime::Error register_kernels(ArrayRef kernels) { +inline ::executorch::runtime::Error register_kernels(ArrayRef kernels, ErrorHandler errorHandler = nullptr) { return ::executorch::ET_RUNTIME_NAMESPACE::register_kernels( - {kernels.data(), kernels.size()}); + {kernels.data(), kernels.size()}, errorHandler); } inline OpFunction getOpsFn( const char* name, @@ -294,5 +297,6 @@ inline ArrayRef get_kernels() { ::executorch::ET_RUNTIME_NAMESPACE::get_registered_kernels(); return ArrayRef(kernels.data(), kernels.size()); } + } // namespace executor } // namespace torch