diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index c15fb93..16c7bd3 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -1,9 +1,11 @@ import asyncio +import contextvars +import functools import inspect from concurrent.futures import Executor from logging import getLogger from time import time -from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints +from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints import anyio from taskiq_dependencies import DependencyGraph @@ -23,25 +25,6 @@ QUEUE_DONE = b"-1" -def _run_sync( - target: Callable[..., Any], - args: List[Any], - kwargs: Dict[str, Any], -) -> Any: - """ - Runs function synchronously. - - We use this function, because - we cannot pass kwargs in loop.run_with_executor(). - - :param target: function to execute. - :param args: list of function's args. - :param kwargs: dict of function's kwargs. - :return: result of function's execution. - """ - return target(*args, **kwargs) - - class Receiver: """Class that uses as a callback handler.""" @@ -255,13 +238,13 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 else: is_coroutine = False # If this is a synchronous function, we - # run it in executor. + # run it in executor and preserve the context. + ctx = contextvars.copy_context() + func = functools.partial(target, *message.args, **kwargs) target_future = loop.run_in_executor( self.executor, - _run_sync, - target, - message.args, - kwargs, + ctx.run, + func, ) timeout = message.labels.get("timeout") if timeout is not None: diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 57637e9..2fd3f83 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -1,8 +1,9 @@ import asyncio +import contextvars import random import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, List, Optional +from typing import Any, ClassVar, Generator, List, Optional import pytest from taskiq_dependencies import Depends @@ -472,3 +473,64 @@ async def task_no_result() -> str: assert resp.return_value is None assert not broker._running_tasks assert isinstance(resp.error, ValueError) + + +EXPECTED_CTX_VALUE = 42 + + +@pytest.fixture() +def ctxvar() -> Generator[contextvars.ContextVar[int], None, None]: + _ctx_variable: contextvars.ContextVar[int] = contextvars.ContextVar( + "taskiq_test_ctx_var", + ) + token = _ctx_variable.set(EXPECTED_CTX_VALUE) + yield _ctx_variable + _ctx_variable.reset(token) + + +@pytest.mark.anyio +async def test_run_task_successful_sync_preserve_contextvars( + ctxvar: contextvars.ContextVar[int], +) -> None: + """Running sync tasks should preserve context vars.""" + + def test_func() -> int: + return ctxvar.get() + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value == EXPECTED_CTX_VALUE + + +@pytest.mark.anyio +async def test_run_task_successful_async_preserve_contextvars( + ctxvar: contextvars.ContextVar[int], +) -> None: + """Running async tasks should preserve context vars.""" + + async def test_func() -> int: + return ctxvar.get() + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value == EXPECTED_CTX_VALUE