|
9 | 9 | import sys
|
10 | 10 | import threading
|
11 | 11 | import warnings
|
| 12 | +from contextvars import ContextVar |
12 | 13 | from pathlib import Path
|
13 | 14 | from types import FrameType
|
14 | 15 | from typing import Any, Awaitable, Callable, TypeVar, cast
|
@@ -126,6 +127,7 @@ def run(self, coro: Any) -> Any:
|
126 | 127 |
|
127 | 128 |
|
128 | 129 | _runner_map: dict[str, _TaskRunner] = {}
|
| 130 | +_loop: ContextVar[asyncio.AbstractEventLoop | None] = ContextVar("_loop", default=None) |
129 | 131 |
|
130 | 132 |
|
131 | 133 | def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
|
@@ -159,22 +161,30 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
|
159 | 161 | pass
|
160 | 162 |
|
161 | 163 | # Run the loop for this thread.
|
162 |
| - # In Python 3.12, a deprecation warning is raised, which |
163 |
| - # may later turn into a RuntimeError. We handle both |
164 |
| - # cases. |
165 |
| - with warnings.catch_warnings(): |
166 |
| - warnings.simplefilter("ignore", DeprecationWarning) |
167 |
| - try: |
168 |
| - loop = asyncio.get_event_loop() |
169 |
| - except RuntimeError: |
170 |
| - loop = asyncio.new_event_loop() |
171 |
| - asyncio.set_event_loop(loop) |
172 |
| - return loop.run_until_complete(inner) |
| 164 | + loop = ensure_event_loop() |
| 165 | + return loop.run_until_complete(inner) |
173 | 166 |
|
174 | 167 | wrapped.__doc__ = coro.__doc__
|
175 | 168 | return wrapped
|
176 | 169 |
|
177 | 170 |
|
| 171 | +def ensure_event_loop(prefer_selector_loop: bool = False) -> asyncio.AbstractEventLoop: |
| 172 | + # Get the loop for this thread, or create a new one. |
| 173 | + loop = _loop.get() |
| 174 | + if loop is not None and not loop.is_closed(): |
| 175 | + return loop |
| 176 | + try: |
| 177 | + loop = asyncio.get_running_loop() |
| 178 | + except RuntimeError: |
| 179 | + if sys.platform == "win32" and prefer_selector_loop: |
| 180 | + loop = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop() |
| 181 | + else: |
| 182 | + loop = asyncio.new_event_loop() |
| 183 | + asyncio.set_event_loop(loop) |
| 184 | + _loop.set(loop) |
| 185 | + return loop |
| 186 | + |
| 187 | + |
178 | 188 | async def ensure_async(obj: Awaitable[T] | T) -> T:
|
179 | 189 | """Convert a non-awaitable object to a coroutine if needed,
|
180 | 190 | and await it if it was not already awaited.
|
|
0 commit comments