Allowing an error handler from caller #12487

Closed
34 changes: 25 additions & 9 deletions runtime/kernel/operator_registry.cpp
@@ -49,7 +49,7 @@ Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);
size_t num_registered_kernels = 0;

// Registers the kernels, but may return an error.
- Error register_kernels_internal(const Span<const Kernel> kernels) {
+ Error register_kernels_internal(const Span<const Kernel> kernels, ErrorHandler errorHandler) {

Contributor

Why is the return value not sufficient?

Contributor

I guess my question is: in what situation can you pass in an errorHandler to operate on the error generated by this function, but couldn't just operate on the error after it's returned?

Contributor Author

If all errors were returned, the caller could simply operate on the error after the call. However, two errors (RegistrationExceedingMaxKernels and RegistrationAlreadyRegistered) are not returned (see this line). If the caller wants to handle these two errors, it needs to pass an errorHandler to get a chance to do so.
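
The situation described above can be sketched in isolation. This is a minimal mock, not the ExecuTorch implementation: the `Error` enum, `register_internal`, and `register_or_abort` are simplified stand-ins (assumptions) for `register_kernels_internal` and `register_kernels`, and `kMax` is a made-up capacity. It shows why a caller that goes through the aborting wrapper never observes a registration error in the return value, while a handler passed in gets to see it first.

```cpp
#include <cstddef>
#include <cstdlib>

// Simplified stand-ins for the runtime's types (assumptions for illustration).
enum class Error { Ok, RegistrationExceedingMaxKernels, RegistrationAlreadyRegistered };
using ErrorHandler = Error (*)(Error);

constexpr std::size_t kMax = 2;   // hypothetical registry capacity
std::size_t num_registered = 0;   // mock registry state
Error last_seen = Error::Ok;      // what the caller's handler observed

// Mock of the internal function: hands the result to the handler if present.
Error register_internal(std::size_t count, ErrorHandler handler) {
  Error err = Error::Ok;
  if (num_registered + count > kMax) {
    err = Error::RegistrationExceedingMaxKernels;
  } else {
    num_registered += count;
  }
  return handler != nullptr ? handler(err) : err;
}

// Mock of the public wrapper: aborts on a registration error, so the caller
// can never inspect such an error through the return value.
Error register_or_abort(std::size_t count, ErrorHandler handler = nullptr) {
  Error err = register_internal(count, handler);
  if (err != Error::Ok) {
    std::abort();
  }
  return err;
}

// A caller-side handler that records the error and passes it through unchanged.
Error record_error(Error e) {
  last_seen = e;
  return e;
}
```

With an empty mock registry, `register_internal(3, record_error)` leaves `last_seen` set to `RegistrationExceedingMaxKernels`. The same request through `register_or_abort` would run the handler and then abort, so the handler is the only place the caller's code sees the error.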

// Operator registration happens in static initialization time before or after
// PAL init, so call it here. It is safe to call multiple times.
::et_pal_init();
@@ -74,12 +74,19 @@ Error register_kernels_internal(const Span<const Kernel> kernels) {
ET_LOG(Error, "%s", kernels[i].name_);
ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
}

+ if (errorHandler != nullptr) {
+ return errorHandler(Error::RegistrationExceedingMaxKernels);
+ }

return Error::RegistrationExceedingMaxKernels;
}
// for debugging purpose
ET_UNUSED const char* lib_name =
et_pal_get_shared_library_name(kernels.data());

+ Error err = Error::Ok;

for (const auto& kernel : kernels) {
// Linear search. This is fine if the number of kernels is small.
for (size_t i = 0; i < num_registered_kernels; i++) {
@@ -88,24 +95,33 @@ Error register_kernels_internal(const Span<const Kernel> kernels) {
kernel.kernel_key_ == k.kernel_key_) {
ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
ET_LOG_KERNEL_KEY(k.kernel_key_);
- return Error::RegistrationAlreadyRegistered;
+ err = Error::RegistrationAlreadyRegistered;
+ continue;
}
}

registered_kernels[num_registered_kernels++] = kernel;
}
- ET_LOG(
- Debug,
- "Successfully registered all kernels from shared library: %s",
- lib_name);

- return Error::Ok;
+ if (errorHandler != nullptr) {
+ err = errorHandler(err);

Contributor

This makes me a bit nervous. Why do we want to consume an Error and translate it into another Error? In practice, are you going to add product-specific Errors such as LoggingFailure?

Contributor

I'm not against an error handler; I'm just not sure what we are doing with the returned error here.

Contributor

A pattern I would definitely be against is converting Error::RegistrationAlreadyRegistered into Error::Ok, if that's what you are trying to do here.
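
The concern in this thread can be made concrete with a small mock. The names below (`Error`, `would_panic`, `swallow_duplicates`, `pass_through`) are illustrative assumptions, not ExecuTorch code. Because the proposed handler type returns an `Error`, nothing stops a handler from mapping `RegistrationAlreadyRegistered` to `Ok`, after which a wrapper that only checks the returned value sees success.

```cpp
// Simplified stand-ins (assumptions for illustration only).
enum class Error { Ok, RegistrationAlreadyRegistered };
using ErrorHandler = Error (*)(Error);

// Mock of a wrapper that would normally panic on a registration error.
// Returns true if the wrapper would have panicked.
bool would_panic(Error raw, ErrorHandler handler) {
  Error err = (handler != nullptr) ? handler(raw) : raw;
  return err != Error::Ok;
}

// A handler that swallows duplicate-registration errors.
Error swallow_duplicates(Error e) {
  return e == Error::RegistrationAlreadyRegistered ? Error::Ok : e;
}

// A handler that only observes and passes the error through.
Error pass_through(Error e) {
  return e;
}
```

Here `would_panic(Error::RegistrationAlreadyRegistered, swallow_duplicates)` is false, while the same call with `pass_through` or `nullptr` is true. One possible alternative, not proposed in this PR, would be an observer-style handler of type `void (*)(Error)`, which can log or record the error but cannot change what the function returns.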

+ }

+ if (err == Error::Ok) {
+ ET_LOG(
+ Debug,
+ "Successfully registered all kernels from shared library: %s",
+ lib_name);
+ }

+ return err;
}

} // namespace

// Registers the kernels, but panics if an error occurs. Always returns Ok.
Error register_kernels(const Span<const Kernel> kernels) {
- Error success = register_kernels_internal(kernels);
+ Error register_kernels(const Span<const Kernel> kernels, ErrorHandler errorHandler) {
+ Error success = register_kernels_internal(kernels, errorHandler);
if (success == Error::RegistrationAlreadyRegistered ||
success == Error::RegistrationExceedingMaxKernels) {
ET_CHECK_MSG(
14 changes: 9 additions & 5 deletions runtime/kernel/operator_registry.h
@@ -237,14 +237,16 @@ ::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
*/
Span<const Kernel> get_registered_kernels();

+ using ErrorHandler = Error (*)(Error errorCode);

/**
* Registers the provided kernels.
*
* @param[in] kernels Kernel objects to register.
* @retval Error::Ok always. Panics on error. This function needs to return a
* non-void type to run at static initialization time.
*/
- ET_NODISCARD Error register_kernels(const Span<const Kernel>);
+ ET_NODISCARD Error register_kernels(const Span<const Kernel>, ErrorHandler errorHandler = nullptr);

/**
* Registers a single kernel.
@@ -253,8 +255,8 @@ ET_NODISCARD Error register_kernels(const Span<const Kernel>);
* @retval Error::Ok always. Panics on error. This function needs to return a
* non-void type to run at static initialization time.
*/
- ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
- return register_kernels({&kernel, 1});
+ ET_NODISCARD inline Error register_kernel(const Kernel& kernel, ErrorHandler errorHandler = nullptr) {
+ return register_kernels({&kernel, 1}, errorHandler);
};

} // namespace ET_RUNTIME_NAMESPACE
@@ -269,12 +271,13 @@ using ::executorch::ET_RUNTIME_NAMESPACE::KernelKey;
using ::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext;
using ::executorch::ET_RUNTIME_NAMESPACE::OpFunction;
using ::executorch::ET_RUNTIME_NAMESPACE::TensorMeta;
+ using ::executorch::ET_RUNTIME_NAMESPACE::ErrorHandler;
using KernelRuntimeContext =
::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext;

- inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels) {
+ inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels, ErrorHandler errorHandler = nullptr) {
return ::executorch::ET_RUNTIME_NAMESPACE::register_kernels(
- {kernels.data(), kernels.size()});
+ {kernels.data(), kernels.size()}, errorHandler);
}
inline OpFunction getOpsFn(
const char* name,
@@ -294,5 +297,6 @@ inline ArrayRef<Kernel> get_kernels() {
::executorch::ET_RUNTIME_NAMESPACE::get_registered_kernels();
return ArrayRef<Kernel>(kernels.data(), kernels.size());
}

} // namespace executor
} // namespace torch
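
A note on the `ErrorHandler` alias declared in this header: because it is a plain function pointer, only free functions and non-capturing lambdas convert to it, so a handler has to keep any state in globals or statics. The sketch below uses stand-in types; the `Error` enum, `register_mock`, `counting_handler`, and `g_error_count` are assumptions for illustration, not ExecuTorch API. It also shows that existing call sites keep compiling when the handler parameter defaults to `nullptr`.

```cpp
// Stand-in types for illustration only.
enum class Error { Ok, RegistrationAlreadyRegistered };
using ErrorHandler = Error (*)(Error errorCode);

int g_error_count = 0;  // handler state must live outside the handler

// Mock registration entry point with the same defaulted-parameter shape.
Error register_mock(bool duplicate, ErrorHandler errorHandler = nullptr) {
  Error err = duplicate ? Error::RegistrationAlreadyRegistered : Error::Ok;
  return errorHandler != nullptr ? errorHandler(err) : err;
}

// A non-capturing lambda converts implicitly to the function-pointer type.
ErrorHandler counting_handler = [](Error e) -> Error {
  if (e != Error::Ok) {
    ++g_error_count;
  }
  return e;
};

// A call site written before the parameter existed still compiles.
Error legacy_call() {
  return register_mock(false);
}
```

A lambda that captures local variables would not convert to `ErrorHandler`; supporting that would need a different parameter type, such as a function pointer plus a `void*` context argument.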