diff --git a/aiomultiprocess/core.py b/aiomultiprocess/core.py index 1cc6dea..992c499 100644 --- a/aiomultiprocess/core.py +++ b/aiomultiprocess/core.py @@ -127,7 +127,7 @@ def __await__(self) -> Any: return self.join().__await__() @staticmethod - def run_async(unit: Unit) -> R: + def run_async(unit: Unit) -> Optional[R]: """Initialize the child process and event loop, then execute the coroutine.""" try: if unit.loop_initializer is None: @@ -143,7 +143,8 @@ def run_async(unit: Unit) -> R: result: R = loop.run_until_complete(unit.target(*unit.args, **unit.kwargs)) return result - + except KeyboardInterrupt: + return None except BaseException: log.exception(f"aio process {os.getpid()} failed") raise @@ -216,13 +217,14 @@ def __init__(self, *args, **kwargs) -> None: self.unit.namespace.result = None @staticmethod - def run_async(unit: Unit) -> R: + def run_async(unit: Unit) -> Optional[R]: """Initialize the child process and event loop, then execute the coroutine.""" try: - result: R = Process.run_async(unit) + result: Optional[R] = Process.run_async(unit) unit.namespace.result = result return result - + except KeyboardInterrupt: + return None except BaseException as e: unit.namespace.result = e raise diff --git a/aiomultiprocess/tests/base.py b/aiomultiprocess/tests/base.py index ce39bbe..89b0cec 100644 --- a/aiomultiprocess/tests/base.py +++ b/aiomultiprocess/tests/base.py @@ -48,6 +48,10 @@ async def raise_fn(): raise RuntimeError("raising") +async def raise_keyboard_interrupt(): + raise KeyboardInterrupt() + + async def terminate(process): await asyncio.sleep(0.5) process.terminate() diff --git a/aiomultiprocess/tests/core.py b/aiomultiprocess/tests/core.py index a3cf988..7b56ee0 100644 --- a/aiomultiprocess/tests/core.py +++ b/aiomultiprocess/tests/core.py @@ -14,6 +14,7 @@ get_dummy_constant, initializer, raise_fn, + raise_keyboard_interrupt, sleepy, two, ) @@ -223,6 +224,13 @@ async def test_raise(self): ) self.assertIsInstance(result, RuntimeError) + @async_test + async def test_keyboard_interrupt(self): + result = await amp.Worker( + target=raise_keyboard_interrupt, name="test_process", initializer=do_nothing + ) + self.assertIsNone(result) + @async_test async def test_sync_target(self): with self.assertRaises(ValueError) as _: