Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions pynumaflow/mapper/_servicer/_sync_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
self.multiproc = multiproc
# create a thread pool for executing UDF code
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)
# Thread-safe event to track shutdown state and prevent race conditions
self._shutdown_event = threading.Event()
self._shutdown_lock = threading.Lock() # NEW: lock for shutdown/error handling

def _handle_error(self, context, error):
    """
    Report a UDF error once, no matter how many threads call this method.

    The shutdown lock serializes callers. The first caller to arrive sets
    the shutdown event and passes the formatted error message to
    ``exit_on_error``; any caller that finds the event already set only
    writes an informational log line and returns.

    Args:
        context: forwarded unchanged to ``exit_on_error``.
        error: the exception being reported; its ``repr`` is embedded in
            the message handed to ``exit_on_error``.
    """
    with self._shutdown_lock:
        # Guard clause: an earlier caller already started the shutdown.
        if self._shutdown_event.is_set():
            _LOGGER.info("Shutdown already initiated by another thread, exiting quietly")
            return

        # Mark shutdown first so other code paths checking the event see it
        # before the error is reported.
        self._shutdown_event.set()
        message = f"{ERR_UDF_EXCEPTION_STRING}: {repr(error)}"
        exit_on_error(context, message, parent=self.multiproc)

def MapFn(
self,
Expand Down Expand Up @@ -56,10 +72,7 @@ def MapFn(
for res in result_queue.read_iterator():
# if the result is an error, handle it accordingly
if isinstance(res, BaseException):
# Terminate the current server process due to exception
exit_on_error(
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc
)
self._handle_error(context, res)
return
# return the result
yield res
Expand All @@ -70,10 +83,7 @@ def MapFn(

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
# Terminate the current server process due to exception
exit_on_error(
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc
)
self._handle_error(context, err)
return

def _process_requests(
Expand All @@ -86,6 +96,10 @@ def _process_requests(
# read through all incoming requests and submit to the
# threadpool for invocation
for request in request_iterator:
# Check if shutdown has been initiated before submitting new tasks
if self._shutdown_event.is_set():
_LOGGER.info("Shutdown initiated, stopping request processing")
break
_ = self.executor.submit(self._invoke_map, context, request, result_queue)
# wait for all tasks to finish after all requests exhausted
self.executor.shutdown(wait=True)
Expand All @@ -101,6 +115,11 @@ def _invoke_map(
result_queue: SyncIterator,
):
try:
# Check if shutdown has been initiated before processing
if self._shutdown_event.is_set():
_LOGGER.info("Shutdown initiated, skipping map invocation")
return

d = Datum(
keys=list(request.request.keys),
value=request.request.value,
Expand All @@ -123,7 +142,11 @@ def _invoke_map(

except BaseException as e:
_LOGGER.critical("MapFn handler error", exc_info=True)
result_queue.put(e)
# Only put the exception in the queue if shutdown hasn't been initiated
if not self._shutdown_event.is_set():
result_queue.put(e)
else:
_LOGGER.info("Shutdown already initiated, not queuing additional error")
return

def IsReady(
Expand Down