diff --git a/taskiq/middlewares/simple_retry_middleware.py b/taskiq/middlewares/simple_retry_middleware.py index 0eaee1b..ca16088 100644 --- a/taskiq/middlewares/simple_retry_middleware.py +++ b/taskiq/middlewares/simple_retry_middleware.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any +from typing import Any, Iterable, Optional from taskiq.abc.middleware import TaskiqMiddleware from taskiq.exceptions import NoResultError @@ -18,10 +18,12 @@ def __init__( default_retry_count: int = 3, default_retry_label: bool = False, no_result_on_retry: bool = True, + types_of_exceptions: Optional[Iterable[type[BaseException]]] = None, ) -> None: self.default_retry_count = default_retry_count self.default_retry_label = default_retry_label self.no_result_on_retry = no_result_on_retry + self.types_of_exceptions = types_of_exceptions async def on_error( self, @@ -42,6 +44,12 @@ async def on_error( :param result: execution result. :param exception: found exception. """ + if self.types_of_exceptions is not None and not isinstance( + exception, + tuple(self.types_of_exceptions), + ): + return + # Valid exception if isinstance(exception, NoResultError): return diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index ef0bcb6..3089d24 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -1,7 +1,7 @@ import datetime import random from logging import getLogger -from typing import Any, Optional +from typing import Any, Iterable, Optional from taskiq import ScheduleSource from taskiq.abc.middleware import TaskiqMiddleware @@ -35,6 +35,7 @@ def __init__( use_delay_exponent: bool = False, max_delay_exponent: float = 60, schedule_source: Optional[ScheduleSource] = None, + types_of_exceptions: Optional[Iterable[type[BaseException]]] = None, ) -> None: """ Initialize retry middleware. @@ -48,6 +49,7 @@ def __init__( :param max_delay_exponent: Maximum allowed delay when using backoff. :param schedule_source: Schedule source to use for scheduling. If None, the default broker will be used. + :param types_of_exceptions: Types of exceptions to retry from. """ super().__init__() self.default_retry_count = default_retry_count @@ -58,6 +60,7 @@ def __init__( self.use_delay_exponent = use_delay_exponent self.max_delay_exponent = max_delay_exponent self.schedule_source = schedule_source + self.types_of_exceptions = types_of_exceptions if not isinstance(schedule_source, (ScheduleSource, type(None))): raise TypeError( @@ -138,6 +141,12 @@ async def on_error( :param result: Execution result. :param exception: Caught exception. """ + if self.types_of_exceptions is not None and not isinstance( + exception, + tuple(self.types_of_exceptions), + ): + return + if isinstance(exception, NoResultError): return diff --git a/tests/middlewares/test_task_retry.py b/tests/middlewares/test_task_retry.py index 7544ab6..8d5dfca 100644 --- a/tests/middlewares/test_task_retry.py +++ b/tests/middlewares/test_task_retry.py @@ -2,7 +2,7 @@ import pytest -from taskiq import InMemoryBroker, SimpleRetryMiddleware +from taskiq import InMemoryBroker, SimpleRetryMiddleware, SmartRetryMiddleware from taskiq.exceptions import NoResultError @@ -151,3 +151,109 @@ def run_task() -> str: assert runs == 1 assert str(resp.error) == str(runs) + + +@pytest.mark.anyio +async def test_retry_of_custom_exc_types_of_simple_middleware() -> None: + # test that the passed error will be handled + broker = InMemoryBroker().with_middlewares( + SimpleRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError, ValueError), + ), + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task() -> None: + nonlocal runs + + runs += 1 + + raise ValueError(runs) + + task = await run_task.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 10 + + # test that an untransmitted error will not be handled + broker = InMemoryBroker().with_middlewares( + SimpleRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError,), + ), + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task2() -> None: + nonlocal runs + + runs += 1 + + raise ValueError(runs) + + task = await run_task2.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 1 + + +@pytest.mark.anyio +async def test_retry_of_custom_exc_types_of_smart_middleware() -> None: + # test that the passed error will be handled + broker = InMemoryBroker().with_middlewares( + SmartRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError, ValueError), + ), + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task() -> None: + nonlocal runs + + runs += 1 + + raise ValueError(runs) + + task = await run_task.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 10 + + # test that an untransmitted error will not be handled + broker = InMemoryBroker().with_middlewares( + SmartRetryMiddleware( + no_result_on_retry=True, + default_retry_label=True, + types_of_exceptions=(KeyError,), + ), + ) + runs = 0 + + @broker.task(max_retries=10) + def run_task2() -> None: + nonlocal runs + + runs += 1 + + raise ValueError(runs) + + task = await run_task2.kiq() + resp = await task.wait_result(timeout=1) + with pytest.raises(ValueError): + resp.raise_for_error() + + assert runs == 1