1
1
//! Python coroutine implementation, used notably when wrapping `async fn`
2
2
//! with `#[pyfunction]`/`#[pymethods]`.
3
- use std:: task:: Waker ;
4
3
use std:: {
5
4
future:: Future ,
6
5
panic,
7
6
pin:: Pin ,
8
7
sync:: Arc ,
9
- task:: { Context , Poll } ,
8
+ task:: { Context , Poll , Waker } ,
10
9
} ;
11
10
12
11
use pyo3_macros:: { pyclass, pymethods} ;
13
12
14
13
use crate :: {
15
- coroutine:: waker:: AsyncioWaker ,
14
+ coroutine:: waker:: CoroutineWaker ,
16
15
exceptions:: { PyAttributeError , PyRuntimeError , PyStopIteration } ,
17
16
pyclass:: IterNextOutput ,
18
- types:: { PyIterator , PyString } ,
19
- IntoPy , Py , PyAny , PyErr , PyObject , PyResult , Python ,
17
+ types:: PyString ,
18
+ IntoPy , Py , PyErr , PyObject , PyResult , Python ,
20
19
} ;
21
20
21
+ mod asyncio;
22
22
pub ( crate ) mod cancel;
23
- mod waker;
23
+ pub ( crate ) mod waker;
24
24
25
25
use crate :: coroutine:: cancel:: ThrowCallback ;
26
26
use crate :: panic:: PanicException ;
@@ -36,7 +36,7 @@ pub struct Coroutine {
36
36
throw_callback : Option < ThrowCallback > ,
37
37
allow_threads : bool ,
38
38
future : Option < Pin < Box < dyn Future < Output = PyResult < PyObject > > + Send > > > ,
39
- waker : Option < Arc < AsyncioWaker > > ,
39
+ waker : Option < Arc < CoroutineWaker > > ,
40
40
}
41
41
42
42
impl Coroutine {
@@ -73,33 +73,37 @@ impl Coroutine {
73
73
}
74
74
}
75
75
76
- fn poll (
76
+ fn poll_inner (
77
77
& mut self ,
78
78
py : Python < ' _ > ,
79
- throw : Option < PyObject > ,
79
+ mut sent_result : Option < Result < PyObject , PyObject > > ,
80
80
) -> PyResult < IterNextOutput < PyObject , PyObject > > {
81
81
// raise if the coroutine has already been run to completion
82
82
let future_rs = match self . future {
83
83
Some ( ref mut fut) => fut,
84
84
None => return Err ( PyRuntimeError :: new_err ( COROUTINE_REUSED_ERROR ) ) ,
85
85
} ;
86
- // reraise thrown exception it
87
- match ( throw, & self . throw_callback ) {
88
- ( Some ( exc) , Some ( cb) ) => cb. throw ( exc. as_ref ( py) ) ,
89
- ( Some ( exc) , None ) => {
90
- self . close ( ) ;
91
- return Err ( PyErr :: from_value ( exc. as_ref ( py) ) ) ;
86
+ // if the future is not pending on a Python awaitable,
87
+ // execute throw callback or complete on close
88
+ if !matches ! ( self . waker, Some ( ref w) if w. yielded_from_awaitable( py) ) {
89
+ match ( sent_result, & self . throw_callback ) {
90
+ ( res @ Some ( Ok ( _) ) , _) => sent_result = res,
91
+ ( Some ( Err ( err) ) , Some ( cb) ) => {
92
+ cb. throw ( err. as_ref ( py) ) ;
93
+ sent_result = Some ( Ok ( py. None ( ) . into ( ) ) ) ;
94
+ }
95
+ ( Some ( Err ( err) ) , None ) => return Err ( PyErr :: from_value ( err. as_ref ( py) ) ) ,
96
+ ( None , _) => return Ok ( IterNextOutput :: Return ( py. None ( ) . into ( ) ) ) ,
92
97
}
93
- _ => { }
94
98
}
95
99
// create a new waker, or try to reset it in place
96
100
if let Some ( waker) = self . waker . as_mut ( ) . and_then ( Arc :: get_mut) {
97
- waker. reset ( ) ;
101
+ waker. reset ( sent_result ) ;
98
102
} else {
99
- self . waker = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ;
103
+ self . waker = Some ( Arc :: new ( CoroutineWaker :: new ( sent_result ) ) ) ;
100
104
}
101
105
let waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ;
102
- // poll the Rust future and forward its results if ready
106
+ // poll the Rust future and forward its results if ready; otherwise, yield from waker
103
107
// polling is UnwindSafe because the future is dropped in case of panic
104
108
let poll = || {
105
109
if self . allow_threads {
@@ -109,29 +113,27 @@ impl Coroutine {
109
113
}
110
114
} ;
111
115
match panic:: catch_unwind ( panic:: AssertUnwindSafe ( poll) ) {
112
- Ok ( Poll :: Ready ( res) ) => {
113
- self . close ( ) ;
114
- return Ok ( IterNextOutput :: Return ( res?) ) ;
115
- }
116
- Err ( err) => {
117
- self . close ( ) ;
118
- return Err ( PanicException :: from_panic_payload ( err) ) ;
119
- }
120
- _ => { }
116
+ Err ( err) => Err ( PanicException :: from_panic_payload ( err) ) ,
117
+ Ok ( Poll :: Ready ( res) ) => Ok ( IterNextOutput :: Return ( res?) ) ,
118
+ Ok ( Poll :: Pending ) => match self . waker . as_ref ( ) . unwrap ( ) . yield_ ( py) {
119
+ Ok ( to_yield) => Ok ( IterNextOutput :: Yield ( to_yield) ) ,
120
+ Err ( err) => Err ( err) ,
121
+ } ,
121
122
}
122
- // otherwise, initialize the waker `asyncio.Future`
123
- if let Some ( future) = self . waker . as_ref ( ) . unwrap ( ) . initialize_future ( py) ? {
124
- // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
125
- // and will yield itself if its result has not been set in polling above
126
- if let Some ( future) = PyIterator :: from_object ( future) . unwrap ( ) . next ( ) {
127
- // future has not been leaked into Python for now, and Rust code can only call
128
- // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
129
- return Ok ( IterNextOutput :: Yield ( future. unwrap ( ) . into ( ) ) ) ;
130
- }
123
+ }
124
+
125
+ fn poll (
126
+ & mut self ,
127
+ py : Python < ' _ > ,
128
+ sent_result : Option < Result < PyObject , PyObject > > ,
129
+ ) -> PyResult < IterNextOutput < PyObject , PyObject > > {
130
+ let result = self . poll_inner ( py, sent_result) ;
131
+ if matches ! ( result, Ok ( IterNextOutput :: Return ( _) ) | Err ( _) ) {
132
+ // the Rust future is dropped, and the field set to `None`
133
+ // to indicate the coroutine has been run to completion
134
+ drop ( self . future . take ( ) ) ;
131
135
}
132
- // if waker has been waken during future polling, this is roughly equivalent to
133
- // `await asyncio.sleep(0)`, so just yield `None`.
134
- Ok ( IterNextOutput :: Yield ( py. None ( ) . into ( ) ) )
136
+ result
135
137
}
136
138
}
137
139
@@ -163,25 +165,24 @@ impl Coroutine {
163
165
}
164
166
}
165
167
166
- fn send ( & mut self , py : Python < ' _ > , _value : & PyAny ) -> PyResult < PyObject > {
167
- iter_result ( self . poll ( py, None ) ?)
168
+ fn send ( & mut self , py : Python < ' _ > , value : PyObject ) -> PyResult < PyObject > {
169
+ iter_result ( self . poll ( py, Some ( Ok ( value ) ) ) ?)
168
170
}
169
171
170
172
fn throw ( & mut self , py : Python < ' _ > , exc : PyObject ) -> PyResult < PyObject > {
171
- iter_result ( self . poll ( py, Some ( exc) ) ?)
173
+ iter_result ( self . poll ( py, Some ( Err ( exc) ) ) ?)
172
174
}
173
175
174
- fn close ( & mut self ) {
175
- // the Rust future is dropped, and the field set to `None`
176
- // to indicate the coroutine has been run to completion
177
- drop ( self . future . take ( ) ) ;
176
+ fn close ( & mut self , py : Python < ' _ > ) -> PyResult < ( ) > {
177
+ self . poll ( py, None ) ?;
178
+ Ok ( ( ) )
178
179
}
179
180
180
181
fn __await__ ( self_ : Py < Self > ) -> Py < Self > {
181
182
self_
182
183
}
183
184
184
185
fn __next__ ( & mut self , py : Python < ' _ > ) -> PyResult < IterNextOutput < PyObject , PyObject > > {
185
- self . poll ( py, None )
186
+ self . poll ( py, Some ( Ok ( py . None ( ) . into ( ) ) ) )
186
187
}
187
188
}
0 commit comments