@@ -14,11 +14,7 @@ use std::{fmt, io};
1414use tracing:: trace;
1515
1616/// Mock implementation of `std::thread::JoinHandle`.
17- pub struct JoinHandle < T > {
18- result : Arc < Mutex < Option < std:: thread:: Result < T > > > > ,
19- notify : rt:: Notify ,
20- thread : Thread ,
21- }
17+ pub struct JoinHandle < T > ( JoinHandleInner < ' static , T > ) ;
2218
2319/// Mock implementation of `std::thread::Thread`.
2420#[ derive( Clone , Debug ) ]
@@ -128,7 +124,7 @@ where
128124 F : ' static ,
129125 T : ' static ,
130126{
131- spawn_internal ( f, None , location ! ( ) )
127+ JoinHandle ( spawn_internal_static ( f, None , location ! ( ) ) )
132128}
133129
134130/// Mock implementation of `std::thread::park`.
@@ -142,38 +138,6 @@ pub fn park() {
142138 rt:: park ( location ! ( ) ) ;
143139}
144140
145- fn spawn_internal < F , T > ( f : F , name : Option < String > , location : Location ) -> JoinHandle < T >
146- where
147- F : FnOnce ( ) -> T ,
148- F : ' static ,
149- T : ' static ,
150- {
151- let result = Arc :: new ( Mutex :: new ( None ) ) ;
152- let notify = rt:: Notify :: new ( true , false ) ;
153-
154- let id = {
155- let name = name. clone ( ) ;
156- let result = result. clone ( ) ;
157- rt:: spawn ( move || {
158- rt:: execution ( |execution| {
159- init_current ( execution, name) ;
160- } ) ;
161-
162- * result. lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
163- notify. notify ( location) ;
164- } )
165- } ;
166-
167- JoinHandle {
168- result,
169- notify,
170- thread : Thread {
171- id : ThreadId { id } ,
172- name,
173- } ,
174- }
175- }
176-
177141impl Builder {
178142 /// Generates the base configuration for spawning a thread, from which
179143 /// configuration methods can be chained.
@@ -206,21 +170,40 @@ impl Builder {
206170 F : Send + ' static ,
207171 T : Send + ' static ,
208172 {
209- Ok ( spawn_internal ( f, self . name , location ! ( ) ) )
173+ Ok ( JoinHandle ( spawn_internal_static ( f, self . name , location ! ( ) ) ) )
174+ }
175+ }
176+
177+ impl Builder {
178+ /// Spawns a new scoped thread using the settings set through this `Builder`.
179+ pub fn spawn_scoped < ' scope , ' env , F , T > (
180+ self ,
181+ scope : & ' scope Scope < ' scope , ' env > ,
182+ f : F ,
183+ ) -> io:: Result < ScopedJoinHandle < ' scope , T > >
184+ where
185+ F : FnOnce ( ) -> T + Send + ' scope ,
186+ T : Send + ' scope ,
187+ {
188+ Ok ( ScopedJoinHandle (
189+ // Safety: the call to this function requires a `&'scope Scope`
190+ // which can only be constructed by `scope()`, which ensures that
191+ // all spawned threads are joined before the `Scope` is destroyed.
192+ unsafe { spawn_internal ( f, self . name , Some ( scope. data . clone ( ) ) , location ! ( ) ) } ,
193+ ) )
210194 }
211195}
212196
213197impl < T > JoinHandle < T > {
214198 /// Waits for the associated thread to finish.
215199 #[ track_caller]
216200 pub fn join ( self ) -> std:: thread:: Result < T > {
217- self . notify . wait ( location ! ( ) ) ;
218- self . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
201+ self . 0 . join ( )
219202 }
220203
221204 /// Gets a handle to the underlying [`Thread`]
222205 pub fn thread ( & self ) -> & Thread {
223- & self . thread
206+ self . 0 . thread ( )
224207 }
225208}
226209
@@ -301,3 +284,220 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
301284 f. pad ( "LocalKey { .. }" )
302285 }
303286}
287+
288+ /// A scope for spawning scoped threads.
289+ ///
290+ /// See [`scope`] for more details.
291+ #[ derive( Debug ) ]
292+ pub struct Scope < ' scope , ' env : ' scope > {
293+ data : Arc < ScopeData > ,
294+ scope : PhantomData < & ' scope mut & ' scope ( ) > ,
295+ env : PhantomData < & ' env mut & ' env ( ) > ,
296+ }
297+
298+ /// An owned permission to join on a scoped thread (block on its termination).
299+ ///
300+ /// See [`Scope::spawn`] for details.
301+ #[ derive( Debug ) ]
302+ pub struct ScopedJoinHandle < ' scope , T > ( JoinHandleInner < ' scope , T > ) ;
303+
304+ /// Create a scope for spawning scoped threads.
305+ ///
306+ /// Mock implementation of [`std::thread::scope`].
307+ #[ track_caller]
308+ pub fn scope < ' env , F , T > ( f : F ) -> T
309+ where
310+ F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
311+ {
312+ let scope = Scope {
313+ data : Arc :: new ( ScopeData {
314+ running_threads : Mutex :: default ( ) ,
315+ main_thread : current ( ) ,
316+ } ) ,
317+ env : PhantomData ,
318+ scope : PhantomData ,
319+ } ;
320+
321+ // Run `f`, but catch panics so we can make sure to wait for all the threads to join.
322+ let result = std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( || f ( & scope) ) ) ;
323+
324+ // Wait until all the threads are finished. This is required to fulfill
325+ // the safety requirements of `spawn_internal`.
326+ let running = loop {
327+ {
328+ let running = scope. data . running_threads . lock ( ) . unwrap ( ) ;
329+ if running. count == 0 {
330+ break running;
331+ }
332+ }
333+ park ( ) ;
334+ } ;
335+
336+ for notify in & running. notify_on_finished {
337+ notify. wait ( location ! ( ) )
338+ }
339+
340+ // Throw any panic from `f`, or the return value of `f` if no thread panicked.
341+ match result {
342+ Err ( e) => std:: panic:: resume_unwind ( e) ,
343+ Ok ( result) => result,
344+ }
345+ }
346+
347+ impl < ' scope , ' env > Scope < ' scope , ' env > {
348+ /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
349+ ///
350+ /// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
351+ pub fn spawn < F , T > ( & ' scope self , f : F ) -> ScopedJoinHandle < ' scope , T >
352+ where
353+ F : FnOnce ( ) -> T + Send + ' scope ,
354+ T : Send + ' scope ,
355+ {
356+ Builder :: new ( )
357+ . spawn_scoped ( self , f)
358+ . expect ( "failed to spawn thread" )
359+ }
360+ }
361+
362+ impl < ' scope , T > ScopedJoinHandle < ' scope , T > {
363+ /// Extracts a handle to the underlying thread.
364+ pub fn thread ( & self ) -> & Thread {
365+ self . 0 . thread ( )
366+ }
367+
368+ /// Waits for the associated thread to finish.
369+ pub fn join ( self ) -> std:: thread:: Result < T > {
370+ self . 0 . join ( )
371+ }
372+ }
373+
374+ /// Handle for joining on a thread with a scope.
375+ #[ derive( Debug ) ]
376+ struct JoinHandleInner < ' scope , T > {
377+ data : Arc < ThreadData < ' scope , T > > ,
378+ notify : rt:: Notify ,
379+ thread : Thread ,
380+ }
381+
382+ /// Spawns a thread without a local scope.
383+ fn spawn_internal_static < F , T > (
384+ f : F ,
385+ name : Option < String > ,
386+ location : Location ,
387+ ) -> JoinHandleInner < ' static , T >
388+ where
389+ F : FnOnce ( ) -> T ,
390+ F : ' static ,
391+ T : ' static ,
392+ {
393+ // Safety: the requirements of `spawn_internal` are trivially satisfied
394+ // since there is no `scope`.
395+ unsafe { spawn_internal ( f, name, None , location) }
396+ }
397+
398+ /// Spawns a thread with an optional scope.
399+ ///
400+ /// The caller must ensure that if `scope` is not None, the provided closure
401+ /// finishes before `'scope` ends.
402+ unsafe fn spawn_internal < ' scope , F , T > (
403+ f : F ,
404+ name : Option < String > ,
405+ scope : Option < Arc < ScopeData > > ,
406+ location : Location ,
407+ ) -> JoinHandleInner < ' scope , T >
408+ where
409+ F : FnOnce ( ) -> T ,
410+ F : ' scope ,
411+ T : ' scope ,
412+ {
413+ let scope_notify = scope
414+ . clone ( )
415+ . map ( |scope| ( scope. add_running_thread ( ) , scope) ) ;
416+ let thread_data = Arc :: new ( ThreadData :: new ( ) ) ;
417+ let notify = rt:: Notify :: new ( true , false ) ;
418+
419+ let id = {
420+ let name = name. clone ( ) ;
421+ let thread_data = thread_data. clone ( ) ;
422+ let body: Box < dyn FnOnce ( ) + ' scope > = Box :: new ( move || {
423+ rt:: execution ( |execution| {
424+ init_current ( execution, name) ;
425+ } ) ;
426+
427+ * thread_data. result . lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
428+ notify. notify ( location) ;
429+
430+ if let Some ( ( notifier, scope) ) = scope_notify {
431+ notifier. notify ( location ! ( ) ) ;
432+ scope. remove_running_thread ( )
433+ }
434+ } ) ;
435+ rt:: spawn ( std:: mem:: transmute :: < _ , Box < dyn FnOnce ( ) > > ( body) )
436+ } ;
437+
438+ JoinHandleInner {
439+ data : thread_data,
440+ notify,
441+ thread : Thread {
442+ id : ThreadId { id } ,
443+ name,
444+ } ,
445+ }
446+ }
447+
448+ /// Data for a running thread.
449+ #[ derive( Debug ) ]
450+ struct ThreadData < ' scope , T > {
451+ result : Mutex < Option < std:: thread:: Result < T > > > ,
452+ _marker : PhantomData < Option < & ' scope ScopeData > > ,
453+ }
454+
455+ impl < ' scope , T > ThreadData < ' scope , T > {
456+ fn new ( ) -> Self {
457+ Self {
458+ result : Mutex :: new ( None ) ,
459+ _marker : PhantomData ,
460+ }
461+ }
462+ }
463+
464+ impl < ' scope , T > JoinHandleInner < ' scope , T > {
465+ fn join ( self ) -> std:: thread:: Result < T > {
466+ self . notify . wait ( location ! ( ) ) ;
467+ self . data . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
468+ }
469+
470+ fn thread ( & self ) -> & Thread {
471+ & self . thread
472+ }
473+ }
474+
475+ #[ derive( Default , Debug ) ]
476+ struct ScopeThreads {
477+ count : usize ,
478+ notify_on_finished : Vec < rt:: Notify > ,
479+ }
480+
481+ #[ derive( Debug ) ]
482+ struct ScopeData {
483+ running_threads : Mutex < ScopeThreads > ,
484+ main_thread : Thread ,
485+ }
486+
487+ impl ScopeData {
488+ fn add_running_thread ( & self ) -> rt:: Notify {
489+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
490+ running. count += 1 ;
491+ let notify = rt:: Notify :: new ( true , false ) ;
492+ running. notify_on_finished . push ( notify) ;
493+ notify
494+ }
495+
496+ fn remove_running_thread ( & self ) {
497+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
498+ running. count -= 1 ;
499+ if running. count == 0 {
500+ self . main_thread . unpark ( )
501+ }
502+ }
503+ }
0 commit comments