|
2 | 2 |
|
3 | 3 | import pytest
|
4 | 4 |
|
5 |
| -from taskiq import InMemoryBroker, SimpleRetryMiddleware |
| 5 | +from taskiq import InMemoryBroker, SimpleRetryMiddleware, SmartRetryMiddleware |
6 | 6 | from taskiq.exceptions import NoResultError
|
7 | 7 |
|
8 | 8 |
|
@@ -151,3 +151,109 @@ def run_task() -> str:
|
151 | 151 |
|
152 | 152 | assert runs == 1
|
153 | 153 | assert str(resp.error) == str(runs)
|
| 154 | + |
| 155 | + |
| 156 | +@pytest.mark.anyio |
| 157 | +async def test_retry_of_custom_exc_types_of_simple_middleware() -> None: |
| 158 | + # test that the passed error will be handled |
| 159 | + broker = InMemoryBroker().with_middlewares( |
| 160 | + SimpleRetryMiddleware( |
| 161 | + no_result_on_retry=True, |
| 162 | + default_retry_label=True, |
| 163 | + types_of_exceptions=(KeyError, ValueError), |
| 164 | + ) |
| 165 | + ) |
| 166 | + runs = 0 |
| 167 | + |
| 168 | + @broker.task(max_retries=10) |
| 169 | + def run_task() -> None: |
| 170 | + nonlocal runs |
| 171 | + |
| 172 | + runs += 1 |
| 173 | + |
| 174 | + raise ValueError() |
| 175 | + |
| 176 | + task = await run_task.kiq() |
| 177 | + resp = await task.wait_result(timeout=1) |
| 178 | + with pytest.raises(ValueError): |
| 179 | + resp.raise_for_error() |
| 180 | + |
| 181 | + assert runs == 10 |
| 182 | + |
| 183 | + # test that an untransmitted error will not be handled |
| 184 | + broker = InMemoryBroker().with_middlewares( |
| 185 | + SimpleRetryMiddleware( |
| 186 | + no_result_on_retry=True, |
| 187 | + default_retry_label=True, |
| 188 | + types_of_exceptions=(KeyError,), |
| 189 | + ) |
| 190 | + ) |
| 191 | + runs = 0 |
| 192 | + |
| 193 | + @broker.task(max_retries=10) |
| 194 | + def run_task2() -> None: |
| 195 | + nonlocal runs |
| 196 | + |
| 197 | + runs += 1 |
| 198 | + |
| 199 | + raise ValueError() |
| 200 | + |
| 201 | + task = await run_task2.kiq() |
| 202 | + resp = await task.wait_result(timeout=1) |
| 203 | + with pytest.raises(ValueError): |
| 204 | + resp.raise_for_error() |
| 205 | + |
| 206 | + assert runs == 1 |
| 207 | + |
| 208 | + |
| 209 | +@pytest.mark.anyio |
| 210 | +async def test_retry_of_custom_exc_types_of_smart_middleware() -> None: |
| 211 | + # test that the passed error will be handled |
| 212 | + broker = InMemoryBroker().with_middlewares( |
| 213 | + SmartRetryMiddleware( |
| 214 | + no_result_on_retry=True, |
| 215 | + default_retry_label=True, |
| 216 | + types_of_exceptions=(KeyError, ValueError), |
| 217 | + ) |
| 218 | + ) |
| 219 | + runs = 0 |
| 220 | + |
| 221 | + @broker.task(max_retries=10) |
| 222 | + def run_task() -> None: |
| 223 | + nonlocal runs |
| 224 | + |
| 225 | + runs += 1 |
| 226 | + |
| 227 | + raise ValueError() |
| 228 | + |
| 229 | + task = await run_task.kiq() |
| 230 | + resp = await task.wait_result(timeout=1) |
| 231 | + with pytest.raises(ValueError): |
| 232 | + resp.raise_for_error() |
| 233 | + |
| 234 | + assert runs == 10 |
| 235 | + |
| 236 | + # test that an untransmitted error will not be handled |
| 237 | + broker = InMemoryBroker().with_middlewares( |
| 238 | + SmartRetryMiddleware( |
| 239 | + no_result_on_retry=True, |
| 240 | + default_retry_label=True, |
| 241 | + types_of_exceptions=(KeyError,), |
| 242 | + ) |
| 243 | + ) |
| 244 | + runs = 0 |
| 245 | + |
| 246 | + @broker.task(max_retries=10) |
| 247 | + def run_task2() -> None: |
| 248 | + nonlocal runs |
| 249 | + |
| 250 | + runs += 1 |
| 251 | + |
| 252 | + raise ValueError() |
| 253 | + |
| 254 | + task = await run_task2.kiq() |
| 255 | + resp = await task.wait_result(timeout=1) |
| 256 | + with pytest.raises(ValueError): |
| 257 | + resp.raise_for_error() |
| 258 | + |
| 259 | + assert runs == 1 |
0 commit comments