From dd0738dc68fa3c8828e7a3c115a491859ef13d5a Mon Sep 17 00:00:00 2001 From: GiperBoreipy Date: Fri, 27 Jun 2025 13:06:41 +0300 Subject: [PATCH 1/5] feat: support for handling custom exceptions in middleware. --- taskiq/middlewares/simple_retry_middleware.py | 9 ++++++++- taskiq/middlewares/smart_retry_middleware.py | 10 +++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/taskiq/middlewares/simple_retry_middleware.py b/taskiq/middlewares/simple_retry_middleware.py index 0eaee1bc..a1a4ec26 100644 --- a/taskiq/middlewares/simple_retry_middleware.py +++ b/taskiq/middlewares/simple_retry_middleware.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any +from typing import Any, Iterable, Optional from taskiq.abc.middleware import TaskiqMiddleware from taskiq.exceptions import NoResultError @@ -18,10 +18,12 @@ def __init__( default_retry_count: int = 3, default_retry_label: bool = False, no_result_on_retry: bool = True, + types_of_exceptions: Optional[Iterable[type[BaseException]]] = None, ) -> None: self.default_retry_count = default_retry_count self.default_retry_label = default_retry_label self.no_result_on_retry = no_result_on_retry + self.types_of_exceptions = types_of_exceptions async def on_error( self, @@ -42,6 +44,11 @@ async def on_error( :param result: execution result. :param exception: found exception. """ + if self.types_of_exceptions is not None and not isinstance( + exception, tuple(self.types_of_exceptions) + ): + return + # Valid exception if isinstance(exception, NoResultError): return diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index ef0bcb63..ea009e86 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -1,7 +1,7 @@ import datetime import random from logging import getLogger -from typing import Any, Optional +from typing import Any, Optional, Iterable from taskiq import ScheduleSource from taskiq.abc.middleware import TaskiqMiddleware @@ -35,6 +35,7 @@ def __init__( use_delay_exponent: bool = False, max_delay_exponent: float = 60, schedule_source: Optional[ScheduleSource] = None, + types_of_exceptions: Optional[Iterable[type[BaseException]]] = None, ) -> None: """ Initialize retry middleware. @@ -48,6 +49,7 @@ def __init__( :param max_delay_exponent: Maximum allowed delay when using backoff. :param schedule_source: Schedule source to use for scheduling. If None, the default broker will be used. + :param types_of_exceptions: Types of exceptions to retry from. """ super().__init__() self.default_retry_count = default_retry_count @@ -58,6 +60,7 @@ def __init__( self.use_delay_exponent = use_delay_exponent self.max_delay_exponent = max_delay_exponent self.schedule_source = schedule_source + self.types_of_exceptions = types_of_exceptions if not isinstance(schedule_source, (ScheduleSource, type(None))): raise TypeError( @@ -138,6 +141,11 @@ async def on_error( :param result: Execution result. :param exception: Caught exception. """ + if self.types_of_exceptions is not None and not isinstance( + exception, tuple(self.types_of_exceptions) + ): + return + if isinstance(exception, NoResultError): return From 21d313c72095ba79c2b9d92ddfbd0e04371e4c08 Mon Sep 17 00:00:00 2001 From: GiperBoreipy Date: Fri, 27 Jun 2025 13:34:16 +0300 Subject: [PATCH 2/5] feat: add tests to handling custom exc types --- tests/middlewares/test_task_retry.py | 108 ++++++++++++++++++++++++++- 1 file changed, 107 insertions(+), 1 deletion(-) diff --git a/tests/middlewares/test_task_retry.py b/tests/middlewares/test_task_retry.py index 7544ab66..51ea9419 100644 --- a/tests/middlewares/test_task_retry.py +++ b/tests/middlewares/test_task_retry.py @@ -2,7 +2,7 @@ import pytest -from taskiq import InMemoryBroker, SimpleRetryMiddleware +from taskiq import InMemoryBroker, SimpleRetryMiddleware, SmartRetryMiddleware from taskiq.exceptions import NoResultError @@ -151,3 +151,109 @@ def run_task() -> str: assert runs == 1 assert str(resp.error) == str(runs) + + +@pytest.mark.anyio +async def test_retry_of_custom_exc_types_of_simple_middleware() -> None: + # test that the passed error will be handled + broker = InMemoryBroker().with_middlewares( + SimpleRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError, ValueError), + ) + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task() -> None: + nonlocal runs + + runs += 1 + + raise ValueError() + + task = await run_task.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 10 + + # test that an untransmitted error will not be handled + broker = InMemoryBroker().with_middlewares( + SimpleRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError,), + ) + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task2() -> None: + nonlocal runs + + runs += 1 + + raise ValueError() + + task = await run_task2.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 1 + + +@pytest.mark.anyio +async def test_retry_of_custom_exc_types_of_smart_middleware() -> None: + # test that the passed error will be handled + broker = InMemoryBroker().with_middlewares( + SmartRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError, ValueError), + ) + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task() -> None: + nonlocal runs + + runs += 1 + + raise ValueError() + + task = await run_task.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 10 + + # test that an untransmitted error will not be handled + broker = InMemoryBroker().with_middlewares( + SmartRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError,), + ) + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task2() -> None: + nonlocal runs + + runs += 1 + + raise ValueError() + + task = await run_task2.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 1 From ccff3d1f8323d06f4b9e5c65b3c07037cc446062 Mon Sep 17 00:00:00 2001 From: GiperBoreipy Date: Tue, 1 Jul 2025 19:11:03 +0300 Subject: [PATCH 3/5] fix: ruff mypy lint --- tests/middlewares/test_task_retry.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/middlewares/test_task_retry.py b/tests/middlewares/test_task_retry.py index 51ea9419..8d5dfca9 100644 --- a/tests/middlewares/test_task_retry.py +++ b/tests/middlewares/test_task_retry.py @@ -161,7 +161,7 @@ async def test_retry_of_custom_exc_types_of_simple_middleware() -> None: no_result_on_retry=True, default_retry_label=True, types_of_exceptions=(KeyError, ValueError), - ) + ), ) runs = 0 @@ -171,7 +171,7 @@ def run_task() -> None: runs += 1 - raise ValueError() + raise ValueError(runs) task = await run_task.kiq() resp = await task.wait_result(timeout=1) @@ -186,7 +186,7 @@ def run_task() -> None: no_result_on_retry=True, default_retry_label=True, types_of_exceptions=(KeyError,), - ) + ), ) runs = 0 @@ -196,7 +196,7 @@ def run_task2() -> None: runs += 1 - raise ValueError() + raise ValueError(runs) task = await run_task2.kiq() resp = await task.wait_result(timeout=1) @@ -214,7 +214,7 @@ async def test_retry_of_custom_exc_types_of_smart_middleware() -> None: no_result_on_retry=True, default_retry_label=True, types_of_exceptions=(KeyError, ValueError), - ) + ), ) runs = 0 @@ -224,7 +224,7 @@ def run_task() -> None: runs += 1 - raise ValueError() + raise ValueError(runs) task = await run_task.kiq() resp = await task.wait_result(timeout=1) @@ -239,7 +239,7 @@ def run_task() -> None: no_result_on_retry=True, default_retry_label=True, types_of_exceptions=(KeyError,), - ) + ), ) runs = 0 @@ -249,7 +249,7 @@ def run_task2() -> None: runs += 1 - raise ValueError() + raise ValueError(runs) task = await run_task2.kiq() resp = await task.wait_result(timeout=1) From 9e4e5854115abea0f8309d1cc3792319005391f6 Mon Sep 17 00:00:00 2001 From: GiperBoreipy Date: Sat, 12 Jul 2025 10:57:47 +0300 Subject: [PATCH 4/5] fix to ruff lint --- taskiq/middlewares/simple_retry_middleware.py | 2 +- taskiq/middlewares/smart_retry_middleware.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/taskiq/middlewares/simple_retry_middleware.py b/taskiq/middlewares/simple_retry_middleware.py index a1a4ec26..01ae077a 100644 --- a/taskiq/middlewares/simple_retry_middleware.py +++ b/taskiq/middlewares/simple_retry_middleware.py @@ -45,7 +45,7 @@ async def on_error( :param exception: found exception. """ if self.types_of_exceptions is not None and not isinstance( - exception, tuple(self.types_of_exceptions) + exception, tuple(self.types_of_exceptions), ): return diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index ea009e86..47dd1378 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -1,7 +1,7 @@ import datetime import random from logging import getLogger -from typing import Any, Optional, Iterable +from typing import Any, Iterable, Optional from taskiq import ScheduleSource from taskiq.abc.middleware import TaskiqMiddleware @@ -142,7 +142,7 @@ async def on_error( :param exception: Caught exception. """ if self.types_of_exceptions is not None and not isinstance( - exception, tuple(self.types_of_exceptions) + exception, tuple(self.types_of_exceptions), ): return From 1de0f1bd26e9a9d9ef4dc7b7cc345c4e4afc9dd1 Mon Sep 17 00:00:00 2001 From: GiperBoreipy Date: Sat, 12 Jul 2025 23:38:21 +0300 Subject: [PATCH 5/5] use pre-commit --- taskiq/middlewares/simple_retry_middleware.py | 3 ++- taskiq/middlewares/smart_retry_middleware.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/taskiq/middlewares/simple_retry_middleware.py b/taskiq/middlewares/simple_retry_middleware.py index 01ae077a..ca16088c 100644 --- a/taskiq/middlewares/simple_retry_middleware.py +++ b/taskiq/middlewares/simple_retry_middleware.py @@ -45,7 +45,8 @@ async def on_error( :param exception: found exception. """ if self.types_of_exceptions is not None and not isinstance( - exception, tuple(self.types_of_exceptions), + exception, + tuple(self.types_of_exceptions), ): return diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index 47dd1378..3089d241 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -142,7 +142,8 @@ async def on_error( :param exception: Caught exception. """ if self.types_of_exceptions is not None and not isinstance( - exception, tuple(self.types_of_exceptions), + exception, + tuple(self.types_of_exceptions), ): return