|
1 | 1 | import pickle
|
2 |
| -from typing import Any, Dict, Optional, TypeVar, Union |
| 2 | +import sys |
| 3 | +from contextlib import asynccontextmanager |
| 4 | +from typing import ( |
| 5 | + TYPE_CHECKING, |
| 6 | + Any, |
| 7 | + AsyncIterator, |
| 8 | + Dict, |
| 9 | + List, |
| 10 | + Optional, |
| 11 | + Tuple, |
| 12 | + TypeVar, |
| 13 | + Union, |
| 14 | +) |
3 | 15 |
|
4 |
| -from redis.asyncio import BlockingConnectionPool, Redis |
| 16 | +from redis.asyncio import BlockingConnectionPool, Redis, Sentinel |
5 | 17 | from redis.asyncio.cluster import RedisCluster
|
6 | 18 | from taskiq import AsyncResultBackend
|
7 | 19 | from taskiq.abc.result_backend import TaskiqResult
|
| 20 | +from taskiq.abc.serializer import TaskiqSerializer |
8 | 21 |
|
9 | 22 | from taskiq_redis.exceptions import (
|
10 | 23 | DuplicateExpireTimeSelectedError,
|
11 | 24 | ExpireTimeMustBeMoreThanZeroError,
|
12 | 25 | ResultIsMissingError,
|
13 | 26 | )
|
| 27 | +from taskiq_redis.serializer import PickleSerializer |
| 28 | + |
| 29 | +if sys.version_info >= (3, 10): |
| 30 | + from typing import TypeAlias |
| 31 | +else: |
| 32 | + from typing_extensions import TypeAlias |
| 33 | + |
| 34 | +if TYPE_CHECKING: |
| 35 | + _Redis: TypeAlias = Redis[bytes] |
| 36 | +else: |
| 37 | + _Redis: TypeAlias = Redis |
14 | 38 |
|
15 | 39 | _ReturnType = TypeVar("_ReturnType")
|
16 | 40 |
|
@@ -267,3 +291,142 @@ async def get_result(
|
267 | 291 | taskiq_result.log = None
|
268 | 292 |
|
269 | 293 | return taskiq_result
|
| 294 | + |
| 295 | + |
| 296 | +class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]): |
| 297 | + """Async result based on redis sentinel.""" |
| 298 | + |
| 299 | + def __init__( |
| 300 | + self, |
| 301 | + sentinels: List[Tuple[str, int]], |
| 302 | + master_name: str, |
| 303 | + keep_results: bool = True, |
| 304 | + result_ex_time: Optional[int] = None, |
| 305 | + result_px_time: Optional[int] = None, |
| 306 | + min_other_sentinels: int = 0, |
| 307 | + sentinel_kwargs: Optional[Any] = None, |
| 308 | + serializer: Optional[TaskiqSerializer] = None, |
| 309 | + **connection_kwargs: Any, |
| 310 | + ) -> None: |
| 311 | + """ |
| 312 | + Constructs a new result backend. |
| 313 | +
|
| 314 | + :param sentinels: list of sentinel host and ports pairs. |
| 315 | + :param master_name: sentinel master name. |
| 316 | + :param keep_results: flag to not remove results from Redis after reading. |
| 317 | + :param result_ex_time: expire time in seconds for result. |
| 318 | + :param result_px_time: expire time in milliseconds for result. |
| 319 | + :param max_connection_pool_size: maximum number of connections in pool. |
| 320 | + :param connection_kwargs: additional arguments for redis BlockingConnectionPool. |
| 321 | +
|
| 322 | + :raises DuplicateExpireTimeSelectedError: if result_ex_time |
| 323 | + and result_px_time are selected. |
| 324 | + :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time |
| 325 | + and result_px_time are equal zero. |
| 326 | + """ |
| 327 | + self.sentinel = Sentinel( |
| 328 | + sentinels=sentinels, |
| 329 | + min_other_sentinels=min_other_sentinels, |
| 330 | + sentinel_kwargs=sentinel_kwargs, |
| 331 | + **connection_kwargs, |
| 332 | + ) |
| 333 | + self.master_name = master_name |
| 334 | + if serializer is None: |
| 335 | + serializer = PickleSerializer() |
| 336 | + self.serializer = serializer |
| 337 | + self.keep_results = keep_results |
| 338 | + self.result_ex_time = result_ex_time |
| 339 | + self.result_px_time = result_px_time |
| 340 | + |
| 341 | + unavailable_conditions = any( |
| 342 | + ( |
| 343 | + self.result_ex_time is not None and self.result_ex_time <= 0, |
| 344 | + self.result_px_time is not None and self.result_px_time <= 0, |
| 345 | + ), |
| 346 | + ) |
| 347 | + if unavailable_conditions: |
| 348 | + raise ExpireTimeMustBeMoreThanZeroError( |
| 349 | + "You must select one expire time param and it must be more than zero.", |
| 350 | + ) |
| 351 | + |
| 352 | + if self.result_ex_time and self.result_px_time: |
| 353 | + raise DuplicateExpireTimeSelectedError( |
| 354 | + "Choose either result_ex_time or result_px_time.", |
| 355 | + ) |
| 356 | + |
| 357 | + @asynccontextmanager |
| 358 | + async def _acquire_master_conn(self) -> AsyncIterator[_Redis]: |
| 359 | + async with self.sentinel.master_for(self.master_name) as redis_conn: |
| 360 | + yield redis_conn |
| 361 | + |
| 362 | + async def set_result( |
| 363 | + self, |
| 364 | + task_id: str, |
| 365 | + result: TaskiqResult[_ReturnType], |
| 366 | + ) -> None: |
| 367 | + """ |
| 368 | + Sets task result in redis. |
| 369 | +
|
| 370 | + Dumps TaskiqResult instance into the bytes and writes |
| 371 | + it to redis. |
| 372 | +
|
| 373 | + :param task_id: ID of the task. |
| 374 | + :param result: TaskiqResult instance. |
| 375 | + """ |
| 376 | + redis_set_params: Dict[str, Union[str, bytes, int]] = { |
| 377 | + "name": task_id, |
| 378 | + "value": self.serializer.dumpb(result), |
| 379 | + } |
| 380 | + if self.result_ex_time: |
| 381 | + redis_set_params["ex"] = self.result_ex_time |
| 382 | + elif self.result_px_time: |
| 383 | + redis_set_params["px"] = self.result_px_time |
| 384 | + |
| 385 | + async with self._acquire_master_conn() as redis: |
| 386 | + await redis.set(**redis_set_params) # type: ignore |
| 387 | + |
| 388 | + async def is_result_ready(self, task_id: str) -> bool: |
| 389 | + """ |
| 390 | + Returns whether the result is ready. |
| 391 | +
|
| 392 | + :param task_id: ID of the task. |
| 393 | +
|
| 394 | + :returns: True if the result is ready else False. |
| 395 | + """ |
| 396 | + async with self._acquire_master_conn() as redis: |
| 397 | + return bool(await redis.exists(task_id)) |
| 398 | + |
| 399 | + async def get_result( |
| 400 | + self, |
| 401 | + task_id: str, |
| 402 | + with_logs: bool = False, |
| 403 | + ) -> TaskiqResult[_ReturnType]: |
| 404 | + """ |
| 405 | + Gets result from the task. |
| 406 | +
|
| 407 | + :param task_id: task's id. |
| 408 | + :param with_logs: if True it will download task's logs. |
| 409 | + :raises ResultIsMissingError: if there is no result when trying to get it. |
| 410 | + :return: task's return value. |
| 411 | + """ |
| 412 | + async with self._acquire_master_conn() as redis: |
| 413 | + if self.keep_results: |
| 414 | + result_value = await redis.get( |
| 415 | + name=task_id, |
| 416 | + ) |
| 417 | + else: |
| 418 | + result_value = await redis.getdel( |
| 419 | + name=task_id, |
| 420 | + ) |
| 421 | + |
| 422 | + if result_value is None: |
| 423 | + raise ResultIsMissingError |
| 424 | + |
| 425 | + taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301 |
| 426 | + result_value, |
| 427 | + ) |
| 428 | + |
| 429 | + if not with_logs: |
| 430 | + taskiq_result.log = None |
| 431 | + |
| 432 | + return taskiq_result |
0 commit comments