Skip to content

Commit 55a0e82

Browse files
committed
feat: add PyFuture to await Python awaitables
1 parent a7679ec commit 55a0e82

File tree

11 files changed

+605
-131
lines changed

11 files changed

+605
-131
lines changed

guide/src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- [Python exceptions](exception.md)
2121
- [Calling Python from Rust](python_from_rust.md)
2222
- [Using `async` and `await`](async-await.md)
23+
- [Awaiting Python awaitables](async-await/pyfuture.md)
2324
- [GIL, mutability and object types](types.md)
2425
- [Parallelism](parallelism.md)
2526
- [Debugging](debugging.md)

guide/src/async-await/pyfuture.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Awaiting Python awaitables
2+
3+
Python awaitable can be awaited on Rust side using [`PyFuture`]({{#PYO3_DOCS_URL}}/pyo3/types/struct.PyFuture.html).
4+
5+
```rust
6+
# #![allow(dead_code)]
7+
use pyo3::{prelude::*, types::PyFuture};
8+
9+
#[pyfunction]
10+
async fn wrap_awaitable(awaitable: PyObject) -> PyResult<PyObject> {
11+
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
12+
future.await
13+
}
14+
```
15+
16+
`PyFuture::from_object` construct a `PyFuture` from a Python awaitable object, by calling its `__await__` method (or `__iter__` for generator-based coroutine).
17+
18+
## Restrictions
19+
20+
`PyFuture` can only be awaited in the context of a PyO3 coroutine. Otherwise, it panics.
21+
22+
```rust
23+
# #![allow(dead_code)]
24+
use pyo3::{prelude::*, types::PyFuture};
25+
26+
#[pyfunction]
27+
fn block_on(awaitable: PyObject) -> PyResult<PyObject> {
28+
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
29+
futures::executor::block_on(future) // ERROR: PyFuture must be awaited in coroutine context
30+
}
31+
```
32+
33+
`PyFuture` must be the only Rust future awaited; it means that it's forbidden to `select!` a `Pyfuture`. Otherwise, it panics.
34+
35+
```rust
36+
# #![allow(dead_code)]
37+
use std::future;
38+
use futures::FutureExt;
39+
use pyo3::{prelude::*, types::PyFuture};
40+
41+
#[pyfunction]
42+
async fn select(awaitable: PyObject) -> PyResult<PyObject> {
43+
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
44+
futures::select_biased! {
45+
_ = future::pending::<()>().fuse() => unreachable!(),
46+
res = future.fuse() => res, // ERROR: Python awaitable mixed with Rust future
47+
}
48+
}
49+
```
50+
51+
These restrictions exist because awaiting a `PyFuture` strongly binds it to the enclosing coroutine. The coroutine will then delegate its `send`/`throw`/`close` methods to the awaited `PyFuture`. If it was awaited in a `select!`, `Coroutine::send` would no able to know if the value passed would have to be delegated to the `Pyfuture` or not.

pyo3-ffi/src/abstract_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ extern "C" {
128128
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
129129
#[cfg(all(not(PyPy), Py_3_10))]
130130
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
131-
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
131+
pub fn PyIter_Send(
132+
iter: *mut PyObject,
133+
arg: *mut PyObject,
134+
presult: *mut *mut PyObject,
135+
) -> c_int;
132136

133137
#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
134138
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;

src/coroutine.rs

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
//! Python coroutine implementation, used notably when wrapping `async fn`
22
//! with `#[pyfunction]`/`#[pymethods]`.
3-
use std::task::Waker;
43
use std::{
54
future::Future,
65
panic,
76
pin::Pin,
87
sync::Arc,
9-
task::{Context, Poll},
8+
task::{Context, Poll, Waker},
109
};
1110

1211
use pyo3_macros::{pyclass, pymethods};
1312

1413
use crate::{
15-
coroutine::waker::AsyncioWaker,
14+
coroutine::waker::CoroutineWaker,
1615
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
1716
pyclass::IterNextOutput,
18-
types::{PyIterator, PyString},
19-
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
17+
types::PyString,
18+
IntoPy, Py, PyErr, PyObject, PyResult, Python,
2019
};
2120

21+
mod asyncio;
2222
pub(crate) mod cancel;
23-
mod waker;
23+
pub(crate) mod waker;
2424

2525
use crate::coroutine::cancel::ThrowCallback;
2626
use crate::panic::PanicException;
@@ -36,7 +36,7 @@ pub struct Coroutine {
3636
throw_callback: Option<ThrowCallback>,
3737
allow_threads: bool,
3838
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
39-
waker: Option<Arc<AsyncioWaker>>,
39+
waker: Option<Arc<CoroutineWaker>>,
4040
}
4141

4242
impl Coroutine {
@@ -73,33 +73,37 @@ impl Coroutine {
7373
}
7474
}
7575

76-
fn poll(
76+
fn poll_inner(
7777
&mut self,
7878
py: Python<'_>,
79-
throw: Option<PyObject>,
79+
mut sent_result: Option<Result<PyObject, PyObject>>,
8080
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
8181
// raise if the coroutine has already been run to completion
8282
let future_rs = match self.future {
8383
Some(ref mut fut) => fut,
8484
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
8585
};
86-
// reraise thrown exception it
87-
match (throw, &self.throw_callback) {
88-
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
89-
(Some(exc), None) => {
90-
self.close();
91-
return Err(PyErr::from_value(exc.as_ref(py)));
86+
// if the future is not pending on a Python awaitable,
87+
// execute throw callback or complete on close
88+
if !matches!(self.waker, Some(ref w) if w.yielded_from_awaitable(py)) {
89+
match (sent_result, &self.throw_callback) {
90+
(res @ Some(Ok(_)), _) => sent_result = res,
91+
(Some(Err(err)), Some(cb)) => {
92+
cb.throw(err.as_ref(py));
93+
sent_result = Some(Ok(py.None().into()));
94+
}
95+
(Some(Err(err)), None) => return Err(PyErr::from_value(err.as_ref(py))),
96+
(None, _) => return Ok(IterNextOutput::Return(py.None().into())),
9297
}
93-
_ => {}
9498
}
9599
// create a new waker, or try to reset it in place
96100
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
97-
waker.reset();
101+
waker.reset(sent_result);
98102
} else {
99-
self.waker = Some(Arc::new(AsyncioWaker::new()));
103+
self.waker = Some(Arc::new(CoroutineWaker::new(sent_result)));
100104
}
101105
let waker = Waker::from(self.waker.clone().unwrap());
102-
// poll the Rust future and forward its results if ready
106+
// poll the Rust future and forward its results if ready; otherwise, yield from waker
103107
// polling is UnwindSafe because the future is dropped in case of panic
104108
let poll = || {
105109
if self.allow_threads {
@@ -109,29 +113,27 @@ impl Coroutine {
109113
}
110114
};
111115
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
112-
Ok(Poll::Ready(res)) => {
113-
self.close();
114-
return Ok(IterNextOutput::Return(res?));
115-
}
116-
Err(err) => {
117-
self.close();
118-
return Err(PanicException::from_panic_payload(err));
119-
}
120-
_ => {}
116+
Err(err) => Err(PanicException::from_panic_payload(err)),
117+
Ok(Poll::Ready(res)) => Ok(IterNextOutput::Return(res?)),
118+
Ok(Poll::Pending) => match self.waker.as_ref().unwrap().yield_(py) {
119+
Ok(to_yield) => Ok(IterNextOutput::Yield(to_yield)),
120+
Err(err) => Err(err),
121+
},
121122
}
122-
// otherwise, initialize the waker `asyncio.Future`
123-
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
124-
// `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
125-
// and will yield itself if its result has not been set in polling above
126-
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
127-
// future has not been leaked into Python for now, and Rust code can only call
128-
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
129-
return Ok(IterNextOutput::Yield(future.unwrap().into()));
130-
}
123+
}
124+
125+
fn poll(
126+
&mut self,
127+
py: Python<'_>,
128+
sent_result: Option<Result<PyObject, PyObject>>,
129+
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
130+
let result = self.poll_inner(py, sent_result);
131+
if matches!(result, Ok(IterNextOutput::Return(_)) | Err(_)) {
132+
// the Rust future is dropped, and the field set to `None`
133+
// to indicate the coroutine has been run to completion
134+
drop(self.future.take());
131135
}
132-
// if waker has been waken during future polling, this is roughly equivalent to
133-
// `await asyncio.sleep(0)`, so just yield `None`.
134-
Ok(IterNextOutput::Yield(py.None().into()))
136+
result
135137
}
136138
}
137139

@@ -163,25 +165,24 @@ impl Coroutine {
163165
}
164166
}
165167

166-
fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
167-
iter_result(self.poll(py, None)?)
168+
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {
169+
iter_result(self.poll(py, Some(Ok(value)))?)
168170
}
169171

170172
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
171-
iter_result(self.poll(py, Some(exc))?)
173+
iter_result(self.poll(py, Some(Err(exc)))?)
172174
}
173175

174-
fn close(&mut self) {
175-
// the Rust future is dropped, and the field set to `None`
176-
// to indicate the coroutine has been run to completion
177-
drop(self.future.take());
176+
fn close(&mut self, py: Python<'_>) -> PyResult<()> {
177+
self.poll(py, None)?;
178+
Ok(())
178179
}
179180

180181
fn __await__(self_: Py<Self>) -> Py<Self> {
181182
self_
182183
}
183184

184185
fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
185-
self.poll(py, None)
186+
self.poll(py, Some(Ok(py.None().into())))
186187
}
187188
}

src/coroutine/asyncio.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//! Coroutine implementation compatible with asyncio.
2+
use crate::sync::GILOnceCell;
3+
use crate::types::{PyCFunction, PyIterator};
4+
use crate::{intern, wrap_pyfunction, IntoPy, Py, PyAny, PyObject, PyResult, Python};
5+
use pyo3_macros::pyfunction;
6+
7+
/// `asyncio.get_running_loop`
8+
fn get_running_loop(py: Python<'_>) -> PyResult<&PyAny> {
9+
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
10+
let import = || -> PyResult<_> {
11+
let module = py.import("asyncio")?;
12+
Ok(module.getattr("get_running_loop")?.into())
13+
};
14+
GET_RUNNING_LOOP
15+
.get_or_try_init(py, import)?
16+
.as_ref(py)
17+
.call0()
18+
}
19+
20+
/// Asyncio-compatible coroutine waker.
21+
///
22+
/// Polling a Rust future yields an `asyncio.Future`, whose `set_result` method is called
23+
/// when `Waker::wake` is called.
24+
pub(super) struct AsyncioWaker {
25+
event_loop: PyObject,
26+
future: PyObject,
27+
}
28+
29+
impl AsyncioWaker {
30+
pub(super) fn new(py: Python<'_>) -> PyResult<Self> {
31+
let event_loop = get_running_loop(py)?.into_py(py);
32+
let future = event_loop.call_method0(py, "create_future")?;
33+
Ok(Self { event_loop, future })
34+
}
35+
36+
pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> {
37+
let __await__;
38+
// `asyncio.Future` must be awaited; in normal case, it implements `__iter__ = __await__`,
39+
// but `create_future` may have been overriden
40+
let mut iter = match PyIterator::from_object(self.future.as_ref(py)) {
41+
Ok(iter) => iter,
42+
Err(_) => {
43+
__await__ = self.future.call_method0(py, intern!(py, "__await__"))?;
44+
PyIterator::from_object(__await__.as_ref(py))?
45+
}
46+
};
47+
// future has not been waken (because `yield_waken` would have been called
48+
// otherwise), so it is expected to yield itself
49+
Ok(iter.next().expect("future didn't yield")?.into_py(py))
50+
}
51+
52+
pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> {
53+
Ok(py.None().into())
54+
}
55+
56+
pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> {
57+
static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
58+
let release_waiter = RELEASE_WAITER
59+
.get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?;
60+
// `Future.set_result` must be called in event loop thread,
61+
// so it requires `call_soon_threadsafe`
62+
let call_soon_threadsafe = self.event_loop.call_method1(
63+
py,
64+
intern!(py, "call_soon_threadsafe"),
65+
(release_waiter, self.future.as_ref(py)),
66+
);
67+
if let Err(err) = call_soon_threadsafe {
68+
// `call_soon_threadsafe` will raise if the event loop is closed;
69+
// instead of catching an unspecific `RuntimeError`, check directly if it's closed.
70+
let is_closed = self.event_loop.call_method0(py, "is_closed")?;
71+
if !is_closed.extract(py)? {
72+
return Err(err);
73+
}
74+
}
75+
Ok(())
76+
}
77+
}
78+
79+
/// Call `future.set_result` if the future is not done.
80+
///
81+
/// Future can be cancelled by the event loop before being waken.
82+
/// See <https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5>
83+
#[pyfunction(crate = "crate")]
84+
fn release_waiter(future: &PyAny) -> PyResult<()> {
85+
let done = future.call_method0(intern!(future.py(), "done"))?;
86+
if !done.extract::<bool>()? {
87+
future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?;
88+
}
89+
Ok(())
90+
}

0 commit comments

Comments
 (0)