diff --git a/src/lib.rs b/src/lib.rs index f23cc44..de15be5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -357,21 +357,11 @@ impl Action { } Some(self.task.clone()) } -} - -fn partition(s: &str, pattern: char) -> (&str, Option<&str>) { - let mut splits = s.splitn(2, pattern); - (splits.next().unwrap(), splits.next()) -} - -impl FromStr for Action { - type Err = String; - /// Parse an action. - /// - /// `s` should be in the format `[p%][cnt*]task[(args)]`, `p%` is the frequency, - /// `cnt` is the max times the action can be triggered. - fn from_str(s: &str) -> Result { + fn from_str_with_resolve_call( + s: &str, + resolve_call: Option<&ResolveCallFunc>, + ) -> Result { let mut remain = s.trim(); let mut args = None; // in case there is '%' in args, we need to parse it first. @@ -421,6 +411,14 @@ impl FromStr for Action { "pause" => Task::Pause, "yield" => Task::Yield, "delay" => Task::Delay(parse_timeout()?), + "call" => { + if let (Some(resolve), Some(arg)) = (resolve_call, args) { + let callback = SyncCallback(resolve(arg)); + Task::Callback(callback) + } else { + return Err(format!("call is unavailable in this context")); + } + } _ => return Err(format!("unrecognized command {:?}", remain)), }; @@ -428,6 +426,23 @@ impl FromStr for Action { } } +fn partition(s: &str, pattern: char) -> (&str, Option<&str>) { + let mut splits = s.splitn(2, pattern); + (splits.next().unwrap(), splits.next()) +} + +impl FromStr for Action { + type Err = String; + + /// Parse an action. + /// + /// `s` should be in the format `[p%][cnt*]task[(args)]`, `p%` is the frequency, + /// `cnt` is the max times the action can be triggered. + fn from_str(s: &str) -> Result { + Self::from_str_with_resolve_call(s, None) + } +} + #[cfg_attr(feature = "cargo-clippy", allow(clippy::mutex_atomic))] #[derive(Debug)] struct FailPoint { @@ -534,6 +549,31 @@ pub struct FailScenario<'a> { scenario_guard: MutexGuard<'a, &'static FailPointRegistry>, } +type ResolveCallFunc = Box Arc>; + +/// Customize behaviors setting up [`FailScenario`]. +#[non_exhaustive] +#[allow(missing_debug_implementations)] +pub struct SetupOptions { + /// Environment variable to use. Default: `"FAILPOINTS"`. + pub env_var_name: &'static str, + + /// Defines how to resolve `call(arg)` as a `task` in the + /// `FAILPOINTS` environment variable. The provided function + /// taks the `arg` string and returns a function to execute + /// as the task. + pub resolve_call: Option, +} + +impl Default for SetupOptions { + fn default() -> Self { + Self { + env_var_name: "FAILPOINTS", + resolve_call: None, + } + } +} + impl<'a> FailScenario<'a> { /// Set up the system for a fail points scenario. /// @@ -555,12 +595,18 @@ impl<'a> FailScenario<'a> { /// /// Panics if an action is not formatted correctly. pub fn setup() -> Self { + Self::setup_with_options(Default::default()) + } + + /// Similar to [`FailScenario::setup`] but takes an extra [`SetupOptions`] + /// for customization. + pub fn setup_with_options(options: SetupOptions) -> Self { // Cleanup first, in case of previous failed/panic'ed test scenarios. let scenario_guard = SCENARIO.lock().unwrap_or_else(|e| e.into_inner()); let mut registry = scenario_guard.registry.write().unwrap(); Self::cleanup(&mut registry); - let failpoints = match env::var("FAILPOINTS") { + let failpoints = match env::var(options.env_var_name) { Ok(s) => s, Err(VarError::NotPresent) => return Self { scenario_guard }, Err(e) => panic!("invalid failpoints: {:?}", e), @@ -574,7 +620,12 @@ impl<'a> FailScenario<'a> { match order { None => panic!("invalid failpoint: {:?}", cfg), Some(order) => { - if let Err(e) = set(&mut registry, name.to_owned(), order) { + if let Err(e) = set( + &mut registry, + name.to_owned(), + order, + options.resolve_call.as_ref(), + ) { panic!("unable to configure failpoint \"{}\": {}", name, e); } } @@ -669,13 +720,14 @@ pub fn eval) -> R>(name: &str, f: F) -> Option { /// times. /// /// The `FAILPOINTS` environment variable accepts this same syntax for its fail -/// point actions. +/// point actions. With [`SetupOptions::resolve_call`] and [`FailScenario::setup_with_options`], +/// `task` can also be `call(arg)` for customized behavior. /// /// A call to `cfg` with a particular fail point name overwrites any existing actions for /// that fail point, including those set via the `FAILPOINTS` environment variable. pub fn cfg>(name: S, actions: &str) -> Result<(), String> { let mut registry = REGISTRY.registry.write().unwrap(); - set(&mut registry, name.into(), actions) + set(&mut registry, name.into(), actions, None) } /// Configure the actions for a fail point at runtime. @@ -746,12 +798,13 @@ fn set( registry: &mut HashMap>, name: String, actions: &str, + resolve_call: Option<&ResolveCallFunc>, ) -> Result<(), String> { let actions_str = actions; // `actions` are in the format of `failpoint[->failpoint...]`. let actions = actions .split("->") - .map(Action::from_str) + .map(|a| Action::from_str_with_resolve_call(a, resolve_call)) .collect::>()?; // Please note that we can't figure out whether there is a failpoint named `name`, // so we may insert a failpoint that doesn't exist at all. @@ -1062,4 +1115,53 @@ mod tests { assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0); assert_eq!(f1(), 0); } + + #[test] + #[cfg_attr(not(feature = "failpoints"), ignore)] + fn test_setup_with_customized_env_name() { + let f1 = || { + fail_point!("setup_with_customized_env_name", |_| 1); + 0 + }; + env::set_var("FOO_FAILPOINTS", "setup_with_customized_env_name=return"); + let scenario = FailScenario::setup_with_options(SetupOptions { + env_var_name: "FOO_FAILPOINTS", + ..Default::default() + }); + assert_eq!(f1(), 1); + scenario.teardown(); + } + + #[test] + #[cfg_attr(not(feature = "failpoints"), ignore)] + fn test_setup_with_customized_resolve_call() { + env::set_var("FAILPOINTS", "customized_resolve_call=3*call(count)"); + + let count = Arc::new(AtomicUsize::new(0)); + let scenario = FailScenario::setup_with_options(SetupOptions { + resolve_call: Some({ + let count = count.clone(); + Box::new(move |arg| { + if arg == "count" { + let count = count.clone(); + return Arc::new(move || { + count.fetch_add(1, Ordering::AcqRel); + }); + } + panic!("unsupported call(): {}", arg); + }) + }), + ..Default::default() + }); + + let f = || { + fail_point!("customized_resolve_call"); + }; + for i in 0..5 { + assert_eq!(count.load(Ordering::Acquire), i.min(3)); + f(); + } + + scenario.teardown(); + } }