@@ -234,6 +234,162 @@ use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, TryLockError};
234
234
use std:: time:: { Duration , Instant } ;
235
235
use std:: { env, thread} ;
236
236
237
+ #[ cfg( feature = "async" ) ]
238
+ mod async_imp {
239
+ use super :: * ;
240
+ use futures:: future:: BoxFuture ;
241
+
242
+ #[ derive( Clone ) ]
243
+ pub ( crate ) struct AsyncCallback (
244
+ Arc < dyn Fn ( ) -> BoxFuture < ' static , ( ) > + Send + Sync + ' static > ,
245
+ ) ;
246
+
247
+ impl Debug for AsyncCallback {
248
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
249
+ f. write_str ( "AsyncCallback()" )
250
+ }
251
+ }
252
+
253
+ impl PartialEq for AsyncCallback {
254
+ #[ allow( clippy:: vtable_address_comparisons) ]
255
+ fn eq ( & self , other : & Self ) -> bool {
256
+ Arc :: ptr_eq ( & self . 0 , & other. 0 )
257
+ }
258
+ }
259
+
260
+ impl AsyncCallback {
261
+ fn new ( f : impl Fn ( ) -> BoxFuture < ' static , ( ) > + Send + Sync + ' static ) -> AsyncCallback {
262
+ AsyncCallback ( Arc :: new ( f) )
263
+ }
264
+
265
+ async fn run ( & self ) {
266
+ let callback = & self . 0 ;
267
+ callback ( ) . await ;
268
+ }
269
+ }
270
+
271
+ /// `fail_point` but with support for async callback.
272
+ #[ macro_export]
273
+ #[ cfg( all( feature = "failpoints" , feature = "async" ) ) ]
274
+ macro_rules! async_fail_point {
275
+ ( $name: expr) => { {
276
+ $crate:: async_eval( $name, |_| {
277
+ panic!( "Return is not supported for the fail point \" {}\" " , $name) ;
278
+ } )
279
+ . await ;
280
+ } } ;
281
+ ( $name: expr, $e: expr) => { {
282
+ if let Some ( res) = $crate:: async_eval( $name, $e) . await {
283
+ return res;
284
+ }
285
+ } } ;
286
+ ( $name: expr, $cond: expr, $e: expr) => { {
287
+ if $cond {
288
+ $crate:: async_fail_point!( $name, $e) ;
289
+ }
290
+ } } ;
291
+ }
292
+
293
+ /// Configures an async callback to be triggered at the specified
294
+ /// failpoint. If the failpoint is not implemented using
295
+ /// `async_fail_point`, the execution will raise an exception.
296
+ pub fn cfg_async_callback < S , F > ( name : S , f : F ) -> Result < ( ) , String >
297
+ where
298
+ S : Into < String > ,
299
+ F : Fn ( ) -> BoxFuture < ' static , ( ) > + Send + Sync + ' static ,
300
+ {
301
+ let mut registry = REGISTRY . registry . write ( ) . unwrap ( ) ;
302
+ let p = registry
303
+ . entry ( name. into ( ) )
304
+ . or_insert_with ( || Arc :: new ( FailPoint :: new ( ) ) ) ;
305
+ let action = Action :: from_async_callback ( f) ;
306
+ let actions = vec ! [ action] ;
307
+ p. set_actions ( "callback" , actions) ;
308
+ Ok ( ( ) )
309
+ }
310
+
311
+ #[ doc( hidden) ]
312
+ pub async fn async_eval < R , F : FnOnce ( Option < String > ) -> R > ( name : & str , f : F ) -> Option < R > {
313
+ let p = {
314
+ let registry = REGISTRY . registry . read ( ) . unwrap ( ) ;
315
+ match registry. get ( name) {
316
+ None => return None ,
317
+ Some ( p) => p. clone ( ) ,
318
+ }
319
+ } ;
320
+ p. async_eval ( name) . await . map ( f)
321
+ }
322
+
323
+ impl Action {
324
+ #[ cfg( feature = "async" ) ]
325
+ fn from_async_callback (
326
+ f : impl Fn ( ) -> BoxFuture < ' static , ( ) > + Send + Sync + ' static ,
327
+ ) -> Action {
328
+ let task = Task :: CallbackAsync ( AsyncCallback :: new ( f) ) ;
329
+ Action {
330
+ task,
331
+ freq : 1.0 ,
332
+ count : None ,
333
+ }
334
+ }
335
+ }
336
+
337
+ impl FailPoint {
338
+ #[ cfg_attr( feature = "cargo-clippy" , allow( clippy:: option_option) ) ]
339
+ async fn async_eval ( & self , name : & str ) -> Option < Option < String > > {
340
+ let task = {
341
+ let task = self
342
+ . actions
343
+ . read ( )
344
+ . unwrap ( )
345
+ . iter ( )
346
+ . filter_map ( Action :: get_task)
347
+ . next ( ) ;
348
+ match task {
349
+ Some ( Task :: Pause ) => {
350
+ // let n = self.async_pause_notify.clone();
351
+ self . async_pause_notify . notified ( ) . await ;
352
+ return None ;
353
+ }
354
+ Some ( t) => t,
355
+ None => return None ,
356
+ }
357
+ } ;
358
+
359
+ match task {
360
+ Task :: Off => { }
361
+ Task :: Return ( s) => return Some ( s) ,
362
+ Task :: Sleep ( _) => panic ! (
363
+ "fail does not support async sleep, please use a async closure to sleep."
364
+ ) ,
365
+ Task :: Panic ( msg) => match msg {
366
+ Some ( ref msg) => panic ! ( "{}" , msg) ,
367
+ None => panic ! ( "failpoint {} panic" , name) ,
368
+ } ,
369
+ Task :: Print ( msg) => match msg {
370
+ Some ( ref msg) => log:: info!( "{}" , msg) ,
371
+ None => log:: info!( "failpoint {} executed." , name) ,
372
+ } ,
373
+ Task :: Pause => unreachable ! ( ) ,
374
+ Task :: Yield => thread:: yield_now ( ) ,
375
+ Task :: Delay ( _) => panic ! (
376
+ "fail does not support async delay, please use a async closure to sleep."
377
+ ) ,
378
+ Task :: Callback ( f) => {
379
+ f. run ( ) ;
380
+ }
381
+ Task :: CallbackAsync ( f) => {
382
+ f. run ( ) . await ;
383
+ }
384
+ }
385
+ None
386
+ }
387
+ }
388
+ }
389
+
390
+ #[ cfg( feature = "async" ) ]
391
+ pub use async_imp:: * ;
392
+
237
393
#[ derive( Clone ) ]
238
394
struct SyncCallback ( Arc < dyn Fn ( ) + Send + Sync > ) ;
239
395
@@ -282,6 +438,8 @@ enum Task {
282
438
Delay ( u64 ) ,
283
439
/// Call callback function.
284
440
Callback ( SyncCallback ) ,
441
+ #[ cfg( feature = "async" ) ]
442
+ CallbackAsync ( async_imp:: AsyncCallback ) ,
285
443
}
286
444
287
445
#[ derive( Debug ) ]
@@ -433,6 +591,8 @@ impl FromStr for Action {
433
591
struct FailPoint {
434
592
pause : Mutex < bool > ,
435
593
pause_notifier : Condvar ,
594
+ #[ cfg( feature = "async" ) ]
595
+ async_pause_notify : tokio:: sync:: Notify ,
436
596
actions : RwLock < Vec < Action > > ,
437
597
actions_str : RwLock < String > ,
438
598
}
@@ -443,13 +603,16 @@ impl FailPoint {
443
603
FailPoint {
444
604
pause : Mutex :: new ( false ) ,
445
605
pause_notifier : Condvar :: new ( ) ,
606
+ #[ cfg( feature = "async" ) ]
607
+ async_pause_notify : tokio:: sync:: Notify :: new ( ) ,
446
608
actions : RwLock :: default ( ) ,
447
609
actions_str : RwLock :: default ( ) ,
448
610
}
449
611
}
450
612
451
613
fn set_actions ( & self , actions_str : & str , actions : Vec < Action > ) {
452
614
loop {
615
+ self . async_pause_notify . notify_waiters ( ) ;
453
616
// TODO: maybe busy waiting here.
454
617
match self . actions . try_write ( ) {
455
618
Err ( TryLockError :: WouldBlock ) => { }
@@ -460,9 +623,11 @@ impl FailPoint {
460
623
}
461
624
Err ( e) => panic ! ( "unexpected poison: {:?}" , e) ,
462
625
}
463
- let mut guard = self . pause . lock ( ) . unwrap ( ) ;
464
- * guard = false ;
465
- self . pause_notifier . notify_all ( ) ;
626
+ {
627
+ let mut guard = self . pause . lock ( ) . unwrap ( ) ;
628
+ * guard = false ;
629
+ self . pause_notifier . notify_all ( ) ;
630
+ }
466
631
}
467
632
}
468
633
@@ -509,6 +674,7 @@ impl FailPoint {
509
674
Task :: Callback ( f) => {
510
675
f. run ( ) ;
511
676
}
677
+ Task :: CallbackAsync ( _) => unreachable ! ( ) ,
512
678
}
513
679
None
514
680
}
@@ -1062,4 +1228,45 @@ mod tests {
1062
1228
assert_eq ! ( rx. recv_timeout( Duration :: from_millis( 500 ) ) . unwrap( ) , 0 ) ;
1063
1229
assert_eq ! ( f1( ) , 0 ) ;
1064
1230
}
1231
+
1232
+ #[ cfg_attr( not( all( feature = "failpoints" , feature = "async" ) ) , ignore) ]
1233
+ #[ tokio:: test]
1234
+ async fn test_async_failpoint ( ) {
1235
+ use std:: time:: Duration ;
1236
+
1237
+ let f1 = async {
1238
+ async_fail_point ! ( "cb" ) ;
1239
+ } ;
1240
+ let f2 = async {
1241
+ async_fail_point ! ( "cb" ) ;
1242
+ } ;
1243
+
1244
+ let counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
1245
+ let counter2 = counter. clone ( ) ;
1246
+ cfg_async_callback ( "cb" , move || {
1247
+ counter2. fetch_add ( 1 , Ordering :: SeqCst ) ;
1248
+ Box :: pin ( async move {
1249
+ tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
1250
+ } )
1251
+ } )
1252
+ . unwrap ( ) ;
1253
+ f1. await ;
1254
+ f2. await ;
1255
+ assert_eq ! ( 2 , counter. load( Ordering :: SeqCst ) ) ;
1256
+
1257
+ cfg ( "pause" , "pause" ) . unwrap ( ) ;
1258
+ let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 1 ) ;
1259
+ let handle = tokio:: spawn ( async move {
1260
+ async_fail_point ! ( "pause" ) ;
1261
+ tx. send ( ( ) ) . await . unwrap ( ) ;
1262
+ } ) ;
1263
+ tokio:: time:: timeout ( Duration :: from_millis ( 500 ) , rx. recv ( ) )
1264
+ . await
1265
+ . unwrap_err ( ) ;
1266
+ remove ( "pause" ) ;
1267
+ tokio:: time:: timeout ( Duration :: from_millis ( 500 ) , rx. recv ( ) )
1268
+ . await
1269
+ . unwrap ( ) ;
1270
+ handle. await . unwrap ( ) ;
1271
+ }
1065
1272
}
0 commit comments