Skip to content

Commit ea06e6c

Browse files
author
Xinye
committed
support async callback and pause
Signed-off-by: Xinye <[email protected]>
1 parent 5bc95a1 commit ea06e6c

File tree

2 files changed

+216
-3
lines changed

2 files changed

+216
-3
lines changed

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,18 @@ edition = "2021"
1414
exclude = ["/.github/*", "/.travis.yml", "/appveyor.yml"]
1515

1616
[dependencies]
17+
futures = { version = "0.3", optional = true }
1718
log = { version = "0.4", features = ["std"] }
1819
once_cell = "1.9.0"
1920
rand = "0.8"
21+
tokio = { version = "1.32", features = [ "sync" ] }
22+
23+
[dev-dependencies]
24+
tokio = { version = "1.32", features = [ "sync", "rt-multi-thread", "time", "macros" ] }
2025

2126
[features]
2227
failpoints = []
28+
async = [ "futures" ]
2329

2430
[package.metadata.docs.rs]
2531
all-features = true

src/lib.rs

Lines changed: 210 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,162 @@ use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, TryLockError};
234234
use std::time::{Duration, Instant};
235235
use std::{env, thread};
236236

237+
#[cfg(feature = "async")]
238+
mod async_imp {
239+
use super::*;
240+
use futures::future::BoxFuture;
241+
242+
#[derive(Clone)]
243+
pub(crate) struct AsyncCallback(
244+
Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static>,
245+
);
246+
247+
impl Debug for AsyncCallback {
248+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249+
f.write_str("AsyncCallback()")
250+
}
251+
}
252+
253+
impl PartialEq for AsyncCallback {
254+
#[allow(clippy::vtable_address_comparisons)]
255+
fn eq(&self, other: &Self) -> bool {
256+
Arc::ptr_eq(&self.0, &other.0)
257+
}
258+
}
259+
260+
impl AsyncCallback {
261+
fn new(f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static) -> AsyncCallback {
262+
AsyncCallback(Arc::new(f))
263+
}
264+
265+
async fn run(&self) {
266+
let callback = &self.0;
267+
callback().await;
268+
}
269+
}
270+
271+
/// `fail_point` but with support for async callback.
272+
#[macro_export]
273+
#[cfg(all(feature = "failpoints", feature = "async"))]
274+
macro_rules! async_fail_point {
275+
($name:expr) => {{
276+
$crate::async_eval($name, |_| {
277+
panic!("Return is not supported for the fail point \"{}\"", $name);
278+
})
279+
.await;
280+
}};
281+
($name:expr, $e:expr) => {{
282+
if let Some(res) = $crate::async_eval($name, $e).await {
283+
return res;
284+
}
285+
}};
286+
($name:expr, $cond:expr, $e:expr) => {{
287+
if $cond {
288+
$crate::async_fail_point!($name, $e);
289+
}
290+
}};
291+
}
292+
293+
/// Configures an async callback to be triggered at the specified
294+
/// failpoint. If the failpoint is not implemented using
295+
/// `async_fail_point`, the execution will raise an exception.
296+
pub fn cfg_async_callback<S, F>(name: S, f: F) -> Result<(), String>
297+
where
298+
S: Into<String>,
299+
F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
300+
{
301+
let mut registry = REGISTRY.registry.write().unwrap();
302+
let p = registry
303+
.entry(name.into())
304+
.or_insert_with(|| Arc::new(FailPoint::new()));
305+
let action = Action::from_async_callback(f);
306+
let actions = vec![action];
307+
p.set_actions("callback", actions);
308+
Ok(())
309+
}
310+
311+
#[doc(hidden)]
312+
pub async fn async_eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
313+
let p = {
314+
let registry = REGISTRY.registry.read().unwrap();
315+
match registry.get(name) {
316+
None => return None,
317+
Some(p) => p.clone(),
318+
}
319+
};
320+
p.async_eval(name).await.map(f)
321+
}
322+
323+
impl Action {
324+
#[cfg(feature = "async")]
325+
fn from_async_callback(
326+
f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
327+
) -> Action {
328+
let task = Task::CallbackAsync(AsyncCallback::new(f));
329+
Action {
330+
task,
331+
freq: 1.0,
332+
count: None,
333+
}
334+
}
335+
}
336+
337+
impl FailPoint {
338+
#[cfg_attr(feature = "cargo-clippy", allow(clippy::option_option))]
339+
async fn async_eval(&self, name: &str) -> Option<Option<String>> {
340+
let task = {
341+
let task = self
342+
.actions
343+
.read()
344+
.unwrap()
345+
.iter()
346+
.filter_map(Action::get_task)
347+
.next();
348+
match task {
349+
Some(Task::Pause) => {
350+
// let n = self.async_pause_notify.clone();
351+
self.async_pause_notify.notified().await;
352+
return None;
353+
}
354+
Some(t) => t,
355+
None => return None,
356+
}
357+
};
358+
359+
match task {
360+
Task::Off => {}
361+
Task::Return(s) => return Some(s),
362+
Task::Sleep(_) => panic!(
363+
"fail does not support async sleep, please use a async closure to sleep."
364+
),
365+
Task::Panic(msg) => match msg {
366+
Some(ref msg) => panic!("{}", msg),
367+
None => panic!("failpoint {} panic", name),
368+
},
369+
Task::Print(msg) => match msg {
370+
Some(ref msg) => log::info!("{}", msg),
371+
None => log::info!("failpoint {} executed.", name),
372+
},
373+
Task::Pause => unreachable!(),
374+
Task::Yield => thread::yield_now(),
375+
Task::Delay(_) => panic!(
376+
"fail does not support async delay, please use a async closure to sleep."
377+
),
378+
Task::Callback(f) => {
379+
f.run();
380+
}
381+
Task::CallbackAsync(f) => {
382+
f.run().await;
383+
}
384+
}
385+
None
386+
}
387+
}
388+
}
389+
390+
#[cfg(feature = "async")]
391+
pub use async_imp::*;
392+
237393
#[derive(Clone)]
238394
struct SyncCallback(Arc<dyn Fn() + Send + Sync>);
239395

@@ -282,6 +438,8 @@ enum Task {
282438
Delay(u64),
283439
/// Call callback function.
284440
Callback(SyncCallback),
441+
#[cfg(feature = "async")]
442+
CallbackAsync(async_imp::AsyncCallback),
285443
}
286444

287445
#[derive(Debug)]
@@ -433,6 +591,8 @@ impl FromStr for Action {
433591
struct FailPoint {
434592
pause: Mutex<bool>,
435593
pause_notifier: Condvar,
594+
#[cfg(feature = "async")]
595+
async_pause_notify: tokio::sync::Notify,
436596
actions: RwLock<Vec<Action>>,
437597
actions_str: RwLock<String>,
438598
}
@@ -443,13 +603,16 @@ impl FailPoint {
443603
FailPoint {
444604
pause: Mutex::new(false),
445605
pause_notifier: Condvar::new(),
606+
#[cfg(feature = "async")]
607+
async_pause_notify: tokio::sync::Notify::new(),
446608
actions: RwLock::default(),
447609
actions_str: RwLock::default(),
448610
}
449611
}
450612

451613
fn set_actions(&self, actions_str: &str, actions: Vec<Action>) {
452614
loop {
615+
self.async_pause_notify.notify_waiters();
453616
// TODO: maybe busy waiting here.
454617
match self.actions.try_write() {
455618
Err(TryLockError::WouldBlock) => {}
@@ -460,9 +623,11 @@ impl FailPoint {
460623
}
461624
Err(e) => panic!("unexpected poison: {:?}", e),
462625
}
463-
let mut guard = self.pause.lock().unwrap();
464-
*guard = false;
465-
self.pause_notifier.notify_all();
626+
{
627+
let mut guard = self.pause.lock().unwrap();
628+
*guard = false;
629+
self.pause_notifier.notify_all();
630+
}
466631
}
467632
}
468633

@@ -509,6 +674,7 @@ impl FailPoint {
509674
Task::Callback(f) => {
510675
f.run();
511676
}
677+
Task::CallbackAsync(_) => unreachable!(),
512678
}
513679
None
514680
}
@@ -1062,4 +1228,45 @@ mod tests {
10621228
assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0);
10631229
assert_eq!(f1(), 0);
10641230
}
1231+
1232+
#[cfg_attr(not(all(feature = "failpoints", feature = "async")), ignore)]
1233+
#[tokio::test]
1234+
async fn test_async_failpoint() {
1235+
use std::time::Duration;
1236+
1237+
let f1 = async {
1238+
async_fail_point!("cb");
1239+
};
1240+
let f2 = async {
1241+
async_fail_point!("cb");
1242+
};
1243+
1244+
let counter = Arc::new(AtomicUsize::new(0));
1245+
let counter2 = counter.clone();
1246+
cfg_async_callback("cb", move || {
1247+
counter2.fetch_add(1, Ordering::SeqCst);
1248+
Box::pin(async move {
1249+
tokio::time::sleep(Duration::from_millis(10)).await;
1250+
})
1251+
})
1252+
.unwrap();
1253+
f1.await;
1254+
f2.await;
1255+
assert_eq!(2, counter.load(Ordering::SeqCst));
1256+
1257+
cfg("pause", "pause").unwrap();
1258+
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
1259+
let handle = tokio::spawn(async move {
1260+
async_fail_point!("pause");
1261+
tx.send(()).await.unwrap();
1262+
});
1263+
tokio::time::timeout(Duration::from_millis(500), rx.recv())
1264+
.await
1265+
.unwrap_err();
1266+
remove("pause");
1267+
tokio::time::timeout(Duration::from_millis(500), rx.recv())
1268+
.await
1269+
.unwrap();
1270+
handle.await.unwrap();
1271+
}
10651272
}

0 commit comments

Comments
 (0)