Skip to content

Commit 150b654

Browse files
committed
feat: expose coroutine constructor
1 parent be65412 commit 150b654

File tree

7 files changed

+119
-61
lines changed

7 files changed

+119
-61
lines changed

guide/src/async-await.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,24 @@ To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_
9696

9797
Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;
9898

99-
*The type does not yet have a public constructor until the design is finalized.*
99+
Coroutine can also be instantiated directly
100+
101+
```rust
102+
# # ![allow(dead_code)]
103+
use pyo3::prelude::*;
104+
use pyo3::coroutine::{CancelHandle, Coroutine};
105+
106+
#[pyfunction]
107+
fn new_coroutine(py: Python<'_>) -> Coroutine {
108+
let mut cancel = CancelHandle::new();
109+
let throw_callback = cancel.throw_callback();
110+
let future = async move {
111+
cancel.cancelled().await;
112+
PyResult::Ok(())
113+
};
114+
Coroutine::new(pyo3::intern!(py, "my_coro"), future)
115+
.with_qualname_prefix("MyClass")
116+
.with_throw_callback(throw_callback)
117+
.with_allow_threads(true)
118+
}
119+
```

newsfragments/3613.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Expose `Coroutine` constructor

pyo3-macros-backend/src/method.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,13 +504,13 @@ impl<'a> FnSpec<'a> {
504504
};
505505
let mut call = quote! {{
506506
let future = #future;
507-
_pyo3::impl_::coroutine::new_coroutine(
507+
_pyo3::coroutine::Coroutine::new(
508508
_pyo3::intern!(py, stringify!(#python_name)),
509-
#qualname_prefix,
510-
#throw_callback,
511-
#allow_threads,
512509
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
513510
)
511+
.with_qualname_prefix(#qualname_prefix)
512+
.with_throw_callback(#throw_callback)
513+
.with_allow_threads(#allow_threads)
514514
}};
515515
if cancel_handle.is_some() {
516516
call = quote! {{

src/coroutine.rs

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use pyo3_macros::{pyclass, pymethods};
1212

1313
use crate::{
1414
coroutine::waker::CoroutineWaker,
15-
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
15+
exceptions::{PyRuntimeError, PyStopIteration},
1616
pyclass::IterNextOutput,
1717
types::PyString,
1818
IntoPy, Py, PyErr, PyObject, PyResult, Python,
@@ -26,20 +26,19 @@ pub(crate) mod cancel;
2626
mod trio;
2727
pub(crate) mod waker;
2828

29-
use crate::coroutine::cancel::ThrowCallback;
3029
use crate::panic::PanicException;
31-
pub use cancel::CancelHandle;
30+
pub use cancel::{CancelHandle, ThrowCallback};
3231

3332
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
3433

3534
/// Python coroutine wrapping a [`Future`].
3635
#[pyclass(crate = "crate")]
3736
pub struct Coroutine {
38-
name: Option<Py<PyString>>,
37+
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
38+
name: Py<PyString>,
3939
qualname_prefix: Option<&'static str>,
4040
throw_callback: Option<ThrowCallback>,
4141
allow_threads: bool,
42-
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
4342
waker: Option<Arc<CoroutineWaker>>,
4443
}
4544

@@ -50,13 +49,7 @@ impl Coroutine {
5049
/// (should always be `None` anyway).
5150
///
5251
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed
53-
pub(crate) fn new<F, T, E>(
54-
name: Option<Py<PyString>>,
55-
qualname_prefix: Option<&'static str>,
56-
throw_callback: Option<ThrowCallback>,
57-
allow_threads: bool,
58-
future: F,
59-
) -> Self
52+
pub fn new<F, T, E>(name: impl Into<Py<PyString>>, future: F) -> Self
6053
where
6154
F: Future<Output = Result<T, E>> + Send + 'static,
6255
T: IntoPy<PyObject>,
@@ -68,15 +61,36 @@ impl Coroutine {
6861
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
6962
};
7063
Self {
71-
name,
72-
qualname_prefix,
73-
throw_callback,
74-
allow_threads,
7564
future: Some(Box::pin(wrap)),
65+
name: name.into(),
66+
qualname_prefix: None,
67+
throw_callback: None,
68+
allow_threads: false,
7669
waker: None,
7770
}
7871
}
7972

73+
/// Set a prefix for `__qualname__`, which will be joined with a "."
74+
pub fn with_qualname_prefix(mut self, prefix: impl Into<Option<&'static str>>) -> Self {
75+
self.qualname_prefix = prefix.into();
76+
self
77+
}
78+
79+
/// Register a callback for coroutine `throw` method.
80+
///
81+
/// The exception passed to `throw` is then redirected to this callback, notifying the
82+
/// associated [`CancelHandle`], without being reraised.
83+
pub fn with_throw_callback(mut self, callback: impl Into<Option<ThrowCallback>>) -> Self {
84+
self.throw_callback = callback.into();
85+
self
86+
}
87+
88+
/// Release the GIL while polling the future wrapped.
89+
pub fn with_allow_threads(mut self, allow_threads: bool) -> Self {
90+
self.allow_threads = allow_threads;
91+
self
92+
}
93+
8094
fn poll_inner(
8195
&mut self,
8296
py: Python<'_>,
@@ -151,22 +165,18 @@ pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResul
151165
#[pymethods(crate = "crate")]
152166
impl Coroutine {
153167
#[getter]
154-
fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
155-
match &self.name {
156-
Some(name) => Ok(name.clone_ref(py)),
157-
None => Err(PyAttributeError::new_err("__name__")),
158-
}
168+
fn __name__(&self, py: Python<'_>) -> Py<PyString> {
169+
self.name.clone_ref(py)
159170
}
160171

161172
#[getter]
162173
fn __qualname__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
163-
match (&self.name, &self.qualname_prefix) {
164-
(Some(name), Some(prefix)) => Ok(format!("{}.{}", prefix, name.as_ref(py).to_str()?)
174+
Ok(match &self.qualname_prefix {
175+
Some(prefix) => format!("{}.{}", prefix, self.name.as_ref(py).to_str()?)
165176
.as_str()
166-
.into_py(py)),
167-
(Some(name), None) => Ok(name.clone_ref(py)),
168-
(None, _) => Err(PyAttributeError::new_err("__qualname__")),
169-
}
177+
.into_py(py),
178+
None => self.name.clone_ref(py),
179+
})
170180
}
171181

172182
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {

src/coroutine/cancel.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl CancelHandle {
5656
Cancelled(self).await
5757
}
5858

59-
#[doc(hidden)]
59+
/// Instantiate a [`ThrowCallback`] associated to this cancel handle.
6060
pub fn throw_callback(&self) -> ThrowCallback {
6161
ThrowCallback(self.0.clone())
6262
}
@@ -71,7 +71,7 @@ impl Future for Cancelled<'_> {
7171
}
7272
}
7373

74-
#[doc(hidden)]
74+
/// Callback for coroutine `throw` method, notifying the associated [`CancelHandle`]
7575
pub struct ThrowCallback(Arc<Inner>);
7676

7777
impl ThrowCallback {

src/impl_/coroutine.rs

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,8 @@
11
use std::future::Future;
22
use std::mem;
33

4-
use crate::coroutine::cancel::ThrowCallback;
54
use crate::pyclass::boolean_struct::False;
6-
use crate::{
7-
coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject,
8-
PyRef, PyRefMut, PyResult, Python,
9-
};
10-
11-
pub fn new_coroutine<F, T, E>(
12-
name: &PyString,
13-
qualname_prefix: Option<&'static str>,
14-
throw_callback: Option<ThrowCallback>,
15-
allow_threads: bool,
16-
future: F,
17-
) -> Coroutine
18-
where
19-
F: Future<Output = Result<T, E>> + Send + 'static,
20-
T: IntoPy<PyObject>,
21-
E: Into<PyErr>,
22-
{
23-
Coroutine::new(
24-
Some(name.into()),
25-
qualname_prefix,
26-
throw_callback,
27-
allow_threads,
28-
future,
29-
)
30-
}
5+
use crate::{Py, PyAny, PyCell, PyClass, PyRef, PyRefMut, PyResult, Python};
316

327
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
338
// SAFETY: Py<T> can be casted as *const PyCell<T>

tests/test_coroutine.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#![cfg(feature = "macros")]
22
#![cfg(not(target_arch = "wasm32"))]
33
use std::ops::Deref;
4+
use std::sync::Arc;
45
use std::{task::Poll, thread, time::Duration};
56

67
use futures::{channel::oneshot, future::poll_fn, FutureExt};
7-
use pyo3::coroutine::CancelHandle;
8+
use pyo3::coroutine::{CancelHandle, Coroutine};
9+
use pyo3::sync::GILOnceCell;
810
use pyo3::types::{IntoPyDict, PyType};
911
use pyo3::{prelude::*, py_run};
1012

@@ -183,3 +185,53 @@ fn test_async_method_receiver() {
183185
async fn method_mut(&mut self) {}
184186
}
185187
}
188+
189+
#[test]
190+
fn multi_thread_event_loop() {
191+
Python::with_gil(|gil| {
192+
let sleep = wrap_pyfunction!(sleep, gil).unwrap();
193+
let test = r#"
194+
import asyncio
195+
import threading
196+
loop = asyncio.new_event_loop()
197+
# spawn the sleep task and run just one iteration of the event loop
198+
# to schedule the sleep wakeup
199+
task = loop.create_task(sleep(0.1))
200+
loop.stop()
201+
loop.run_forever()
202+
assert not task.done()
203+
# spawn a thread to complete the execution of the sleep task
204+
def target(loop, task):
205+
loop.run_until_complete(task)
206+
thread = threading.Thread(target=target, args=(loop, task))
207+
thread.start()
208+
thread.join()
209+
assert task.result() == 42
210+
"#;
211+
py_run!(gil, sleep, test);
212+
})
213+
}
214+
215+
#[test]
216+
fn closed_event_loop() {
217+
let waker = Arc::new(GILOnceCell::new());
218+
let waker2 = waker.clone();
219+
let future = poll_fn(move |cx| {
220+
Python::with_gil(|gil| waker2.set(gil, cx.waker().clone()).unwrap());
221+
Poll::Pending::<PyResult<()>>
222+
});
223+
Python::with_gil(|gil| {
224+
let register_waker = Coroutine::new("register_waker".into_py(gil), future).into_py(gil);
225+
let test = r#"
226+
import asyncio
227+
loop = asyncio.new_event_loop()
228+
# register a waker by spawning a task and polling it once, then close the loop
229+
task = loop.create_task(register_waker)
230+
loop.stop()
231+
loop.run_forever()
232+
loop.close()
233+
"#;
234+
py_run!(gil, register_waker, test);
235+
Python::with_gil(|gil| waker.get(gil).unwrap().wake_by_ref())
236+
})
237+
}

0 commit comments

Comments
 (0)