diff --git a/Cargo.lock b/Cargo.lock index d242b5922..a919bc748 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10257,6 +10257,48 @@ dependencies = [ "sp-runtime", ] +[[package]] +name = "pallet-rate-limiting" +version = "0.1.0" +dependencies = [ + "frame-benchmarking", + "frame-support", + "frame-system", + "parity-scale-codec", + "scale-info", + "serde", + "sp-core", + "sp-io", + "sp-runtime", + "sp-std", + "subtensor-runtime-common", +] + +[[package]] +name = "pallet-rate-limiting-rpc" +version = "0.1.0" +dependencies = [ + "jsonrpsee", + "pallet-rate-limiting-runtime-api", + "sp-api", + "sp-blockchain", + "sp-runtime", + "subtensor-runtime-common", +] + +[[package]] +name = "pallet-rate-limiting-runtime-api" +version = "0.1.0" +dependencies = [ + "pallet-rate-limiting", + "parity-scale-codec", + "scale-info", + "serde", + "sp-api", + "sp-std", + "subtensor-runtime-common", +] + [[package]] name = "pallet-recovery" version = "41.0.0" diff --git a/Cargo.toml b/Cargo.toml index 613900491..1faa26f23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,8 @@ members = [ "common", "node", "pallets/*", + "pallets/rate-limiting/runtime-api", + "pallets/rate-limiting/rpc", "precompiles", "primitives/*", "runtime", @@ -59,6 +61,9 @@ pallet-subtensor = { path = "pallets/subtensor", default-features = false } pallet-subtensor-swap = { path = "pallets/swap", default-features = false } pallet-subtensor-swap-runtime-api = { path = "pallets/swap/runtime-api", default-features = false } pallet-subtensor-swap-rpc = { path = "pallets/swap/rpc", default-features = false } +pallet-rate-limiting = { path = "pallets/rate-limiting", default-features = false } +pallet-rate-limiting-runtime-api = { path = "pallets/rate-limiting/runtime-api", default-features = false } +pallet-rate-limiting-rpc = { path = "pallets/rate-limiting/rpc", default-features = false } procedural-fork = { path = "support/procedural-fork", default-features = false } safe-math = { path = "primitives/safe-math", default-features = false } share-pool = { path = "primitives/share-pool", default-features = false } diff --git a/pallets/rate-limiting/Cargo.toml b/pallets/rate-limiting/Cargo.toml new file mode 100644 index 000000000..344714562 --- /dev/null +++ b/pallets/rate-limiting/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "pallet-rate-limiting" +version = "0.1.0" +edition.workspace = true + +[lints] +workspace = true + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] + +[dependencies] +codec = { workspace = true, features = ["derive"] } +frame-benchmarking = { workspace = true, optional = true } +frame-support.workspace = true +frame-system.workspace = true +scale-info = { workspace = true, features = ["derive"] } +serde = { workspace = true, features = ["derive"], optional = true } +sp-std.workspace = true +subtensor-runtime-common.workspace = true + +[dev-dependencies] +sp-core.workspace = true +sp-io.workspace = true +sp-runtime.workspace = true + +[features] +default = ["std"] +std = [ + "codec/std", + "frame-benchmarking?/std", + "frame-support/std", + "frame-system/std", + "scale-info/std", + "serde", + "sp-std/std", + "subtensor-runtime-common/std", +] +runtime-benchmarks = [ + "frame-benchmarking", +] +try-runtime = [ + "frame-support/try-runtime", + "frame-system/try-runtime", +] diff --git a/pallets/rate-limiting/rpc/Cargo.toml b/pallets/rate-limiting/rpc/Cargo.toml new file mode 100644 index 000000000..d5bf689e8 --- /dev/null +++ b/pallets/rate-limiting/rpc/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pallet-rate-limiting-rpc" +version = "0.1.0" +description = "RPC interface for the rate limiting pallet" +edition.workspace = true + +[dependencies] +jsonrpsee = { workspace = true, features = ["client-core", "server", "macros"] } +sp-api.workspace = true +sp-blockchain.workspace = true +sp-runtime.workspace = true +pallet-rate-limiting-runtime-api.workspace = true +subtensor-runtime-common = { workspace = true, default-features = false } + +[features] +default = ["std"] +std = [ + "sp-api/std", + "sp-runtime/std", + "pallet-rate-limiting-runtime-api/std", + "subtensor-runtime-common/std", +] diff --git a/pallets/rate-limiting/rpc/src/lib.rs b/pallets/rate-limiting/rpc/src/lib.rs new file mode 100644 index 000000000..ca7452a7a --- /dev/null +++ b/pallets/rate-limiting/rpc/src/lib.rs @@ -0,0 +1,82 @@ +//! RPC interface for the rate limiting pallet. + +use jsonrpsee::{ + core::RpcResult, + proc_macros::rpc, + types::{ErrorObjectOwned, error::ErrorObject}, +}; +use sp_api::ProvideRuntimeApi; +use sp_blockchain::HeaderBackend; +use sp_runtime::traits::Block as BlockT; +use std::sync::Arc; + +pub use pallet_rate_limiting_runtime_api::{RateLimitRpcResponse, RateLimitingRuntimeApi}; + +#[rpc(client, server)] +pub trait RateLimitingRpcApi { + #[method(name = "rateLimiting_getRateLimit")] + fn get_rate_limit( + &self, + pallet: Vec, + extrinsic: Vec, + at: Option, + ) -> RpcResult>; +} + +/// Error type of this RPC api. +pub enum Error { + /// The call to runtime failed. + RuntimeError(String), +} + +impl From for ErrorObjectOwned { + fn from(e: Error) -> Self { + match e { + Error::RuntimeError(e) => ErrorObject::owned(1, e, None::<()>), + } + } +} + +impl From for i32 { + fn from(e: Error) -> i32 { + match e { + Error::RuntimeError(_) => 1, + } + } +} + +/// RPC implementation for the rate limiting pallet. +pub struct RateLimiting { + client: Arc, + _marker: std::marker::PhantomData, +} + +impl RateLimiting { + /// Creates a new instance of the rate limiting RPC helper. + pub fn new(client: Arc) -> Self { + Self { + client, + _marker: Default::default(), + } + } +} + +impl RateLimitingRpcApiServer<::Hash> for RateLimiting +where + Block: BlockT, + C: ProvideRuntimeApi + HeaderBackend + Send + Sync + 'static, + C::Api: RateLimitingRuntimeApi, +{ + fn get_rate_limit( + &self, + pallet: Vec, + extrinsic: Vec, + at: Option<::Hash>, + ) -> RpcResult> { + let api = self.client.runtime_api(); + let at = at.unwrap_or_else(|| self.client.info().best_hash); + + api.get_rate_limit(at, pallet, extrinsic) + .map_err(|e| Error::RuntimeError(format!("Unable to fetch rate limit: {e:?}")).into()) + } +} diff --git a/pallets/rate-limiting/runtime-api/Cargo.toml b/pallets/rate-limiting/runtime-api/Cargo.toml new file mode 100644 index 000000000..2847d865d --- /dev/null +++ b/pallets/rate-limiting/runtime-api/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "pallet-rate-limiting-runtime-api" +version = "0.1.0" +description = "Runtime API for the rate limiting pallet" +edition.workspace = true + +[dependencies] +codec = { workspace = true, features = ["derive"] } +scale-info = { workspace = true, features = ["derive"] } +sp-api.workspace = true +sp-std.workspace = true +pallet-rate-limiting.workspace = true +subtensor-runtime-common = { workspace = true, default-features = false } +serde = { workspace = true, features = ["derive"], optional = true } + +[features] +default = ["std"] +std = [ + "codec/std", + "scale-info/std", + "sp-api/std", + "sp-std/std", + "pallet-rate-limiting/std", + "subtensor-runtime-common/std", + "serde", +] diff --git a/pallets/rate-limiting/runtime-api/src/lib.rs b/pallets/rate-limiting/runtime-api/src/lib.rs new file mode 100644 index 000000000..98b55e9a2 --- /dev/null +++ b/pallets/rate-limiting/runtime-api/src/lib.rs @@ -0,0 +1,25 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +use codec::{Decode, Encode}; +use pallet_rate_limiting::RateLimitKind; +use scale_info::TypeInfo; +use sp_std::vec::Vec; +use subtensor_runtime_common::BlockNumber; + +#[cfg(feature = "std")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "std", derive(Serialize, Deserialize))] +#[derive(Clone, Debug, Decode, Encode, Eq, PartialEq, TypeInfo)] +pub struct RateLimitRpcResponse { + pub global: Option>, + pub contextual: Vec<(Vec, RateLimitKind)>, + pub default_limit: BlockNumber, + pub resolved: Option, +} + +sp_api::decl_runtime_apis! { + pub trait RateLimitingRuntimeApi { + fn get_rate_limit(pallet: Vec, extrinsic: Vec) -> Option; + } +} diff --git a/pallets/rate-limiting/src/benchmarking.rs b/pallets/rate-limiting/src/benchmarking.rs new file mode 100644 index 000000000..65d547ab0 --- /dev/null +++ b/pallets/rate-limiting/src/benchmarking.rs @@ -0,0 +1,91 @@ +//! Benchmarking setup for pallet-rate-limiting +#![cfg(feature = "runtime-benchmarks")] +#![allow(clippy::arithmetic_side_effects)] + +use codec::Decode; +use frame_benchmarking::v2::*; +use frame_system::{RawOrigin, pallet_prelude::BlockNumberFor}; + +use super::*; + +pub trait BenchmarkHelper { + fn sample_call() -> Call; +} + +impl BenchmarkHelper for () +where + Call: Decode, +{ + fn sample_call() -> Call { + Decode::decode(&mut &[][..]).expect("Provide a call via BenchmarkHelper::sample_call") + } +} + +fn sample_call() -> Box<::RuntimeCall> +where + T::BenchmarkHelper: BenchmarkHelper<::RuntimeCall>, +{ + Box::new(T::BenchmarkHelper::sample_call()) +} + +#[benchmarks] +mod benchmarks { + use super::*; + + #[benchmark] + fn set_rate_limit() { + let call = sample_call::(); + let limit = RateLimitKind::>::Exact(BlockNumberFor::::from(10u32)); + let scope = ::LimitScopeResolver::context(call.as_ref()); + let identifier = + TransactionIdentifier::from_call::(call.as_ref()).expect("identifier"); + + #[extrinsic_call] + _(RawOrigin::Root, call, limit.clone()); + + let stored = Limits::::get(&identifier).expect("limit stored"); + match (scope, &stored) { + (Some(ref sc), RateLimit::Scoped(map)) => { + assert_eq!(map.get(sc), Some(&limit)); + } + (None, RateLimit::Global(kind)) | (Some(_), RateLimit::Global(kind)) => { + assert_eq!(kind, &limit); + } + (None, RateLimit::Scoped(map)) => { + assert!(map.values().any(|k| k == &limit)); + } + } + } + + #[benchmark] + fn clear_rate_limit() { + let call = sample_call::(); + let limit = RateLimitKind::>::Exact(BlockNumberFor::::from(10u32)); + let scope = ::LimitScopeResolver::context(call.as_ref()); + + // Pre-populate limit for benchmark call + let identifier = + TransactionIdentifier::from_call::(call.as_ref()).expect("identifier"); + match scope.clone() { + Some(sc) => Limits::::insert(identifier, RateLimit::scoped_single(sc, limit)), + None => Limits::::insert(identifier, RateLimit::global(limit)), + } + + #[extrinsic_call] + _(RawOrigin::Root, call); + + assert!(Limits::::get(identifier).is_none()); + } + + #[benchmark] + fn set_default_rate_limit() { + let block_span = BlockNumberFor::::from(10u32); + + #[extrinsic_call] + _(RawOrigin::Root, block_span); + + assert_eq!(DefaultLimit::::get(), block_span); + } + + impl_benchmark_test_suite!(Pallet, crate::mock::new_test_ext(), crate::mock::Test); +} diff --git a/pallets/rate-limiting/src/lib.rs b/pallets/rate-limiting/src/lib.rs new file mode 100644 index 000000000..0579249b3 --- /dev/null +++ b/pallets/rate-limiting/src/lib.rs @@ -0,0 +1,473 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +//! Rate limiting for runtime calls with optional contextual restrictions. +//! +//! # Overview +//! +//! `pallet-rate-limiting` lets a runtime restrict how frequently particular calls can execute. +//! Limits are stored on-chain, keyed by the call's pallet/variant pair. Each entry can specify an +//! exact block span or defer to a configured default. The pallet exposes three extrinsics, +//! restricted by [`Config::AdminOrigin`], to manage this data: +//! +//! - [`set_rate_limit`](pallet::Pallet::set_rate_limit): assign a limit to an extrinsic by +//! supplying a [`RateLimitKind`] span. The pallet infers the *limit scope* (for example a +//! `netuid`) using [`Config::LimitScopeResolver`] and stores the configuration for that scope, or +//! globally when no scope is resolved. +//! - [`clear_rate_limit`](pallet::Pallet::clear_rate_limit): remove a stored limit for the scope +//! derived from the provided call (or the global entry when no scope resolves). +//! - [`set_default_rate_limit`](pallet::Pallet::set_default_rate_limit): set the global default +//! block span used by `RateLimitKind::Default` entries. +//! +//! The pallet also tracks the last block in which a rate-limited call was executed, per optional +//! *usage key*. A usage key may refine tracking beyond the limit scope (for example combining a +//! `netuid` with a hyperparameter name), so the two concepts are explicitly separated in the +//! configuration. +//! +//! Each storage map is namespaced by pallet instance; runtimes can deploy multiple independent +//! instances to manage distinct rate-limiting scopes. +//! +//! # Transaction extension +//! +//! Enforcement happens via [`RateLimitTransactionExtension`], which implements +//! `sp_runtime::traits::TransactionExtension`. The extension consults `Limits`, fetches the current +//! block, and decides whether the call is eligible. If successful, it returns metadata that causes +//! [`LastSeen`](pallet::LastSeen) to update after dispatch. A rejected call yields +//! `InvalidTransaction::Custom(1)`. +//! +//! To enable the extension, add it to your runtime's transaction extension tuple. For example: +//! +//! ```ignore +//! pub type TransactionExtensions = ( +//! // ... other extensions ... +//! pallet_rate_limiting::RateLimitTransactionExtension, +//! ); +//! ``` +//! +//! # Context resolvers +//! +//! The pallet relies on two resolvers, both implementing [`RateLimitContextResolver`]: +//! +//! - [`Config::LimitScopeResolver`], which determines how limits are stored (for example by +//! returning a `netuid`). When this resolver returns `None`, the configuration is stored as a +//! global fallback. +//! - [`Config::UsageResolver`], which decides how executions are tracked in +//! [`LastSeen`](pallet::LastSeen). This can refine the limit scope (for example by returning a +//! tuple of `(netuid, hyperparameter)`). +//! +//! Each resolver receives the call and may return `Some(identifier)` when scoping is required, or +//! `None` to use the global entry. Extrinsics such as +//! [`set_rate_limit`](pallet::Pallet::set_rate_limit) automatically consult these resolvers. +//! +//! ```ignore +//! pub struct WeightsContextResolver; +//! +//! // Limits are scoped per netuid. +//! pub struct ScopeResolver; +//! impl pallet_rate_limiting::RateLimitContextResolver for ScopeResolver { +//! fn context(call: &RuntimeCall) -> Option { +//! match call { +//! RuntimeCall::Subtensor(pallet_subtensor::Call::set_weights { netuid, .. }) => { +//! Some(*netuid) +//! } +//! _ => None, +//! } +//! } +//! } +//! +//! // Usage tracking distinguishes hyperparameter + netuid. +//! pub struct UsageResolver; +//! impl pallet_rate_limiting::RateLimitContextResolver +//! for UsageResolver +//! { +//! fn context(call: &RuntimeCall) -> Option<(NetUid, HyperParam)> { +//! match call { +//! RuntimeCall::Subtensor(pallet_subtensor::Call::set_hyperparam { +//! netuid, +//! hyper, +//! .. +//! }) => Some((*netuid, *hyper)), +//! _ => None, +//! } +//! } +//! } +//! +//! impl pallet_rate_limiting::Config for Runtime { +//! type RuntimeCall = RuntimeCall; +//! type LimitScope = NetUid; +//! type LimitScopeResolver = ScopeResolver; +//! type UsageKey = (NetUid, HyperParam); +//! type UsageResolver = UsageResolver; +//! type AdminOrigin = frame_system::EnsureRoot; +//! } +//! ``` + +#[cfg(feature = "runtime-benchmarks")] +pub use benchmarking::BenchmarkHelper; +pub use pallet::*; +pub use tx_extension::RateLimitTransactionExtension; +pub use types::{RateLimit, RateLimitContextResolver, RateLimitKind, TransactionIdentifier}; + +#[cfg(feature = "runtime-benchmarks")] +mod benchmarking; +mod tx_extension; +mod types; + +#[cfg(test)] +mod mock; + +#[cfg(test)] +mod tests; + +#[frame_support::pallet] +pub mod pallet { + use codec::Codec; + use frame_support::{ + pallet_prelude::*, + sp_runtime::traits::{Saturating, Zero}, + traits::{BuildGenesisConfig, EnsureOrigin, GetCallMetadata}, + }; + use frame_system::pallet_prelude::*; + use sp_std::{convert::TryFrom, marker::PhantomData, vec::Vec}; + + #[cfg(feature = "runtime-benchmarks")] + use crate::benchmarking::BenchmarkHelper as BenchmarkHelperTrait; + use crate::types::{RateLimit, RateLimitContextResolver, RateLimitKind, TransactionIdentifier}; + + /// Configuration trait for the rate limiting pallet. + #[pallet::config] + pub trait Config: frame_system::Config + where + BlockNumberFor: MaybeSerializeDeserialize, + { + /// The overarching runtime call type. + type RuntimeCall: Parameter + + Codec + + GetCallMetadata + + IsType<::RuntimeCall>; + + /// Origin permitted to configure rate limits. + type AdminOrigin: EnsureOrigin>; + + /// Scope identifier used to namespace stored rate limits. + type LimitScope: Parameter + Clone + PartialEq + Eq + Ord + MaybeSerializeDeserialize; + + /// Resolves the scope for the given runtime call when configuring limits. + type LimitScopeResolver: RateLimitContextResolver<>::RuntimeCall, Self::LimitScope>; + + /// Usage key tracked in [`LastSeen`] for rate-limited calls. + type UsageKey: Parameter + Clone + PartialEq + Eq + Ord + MaybeSerializeDeserialize; + + /// Resolves the usage key for the given runtime call when enforcing limits. + type UsageResolver: RateLimitContextResolver<>::RuntimeCall, Self::UsageKey>; + + /// Helper used to construct runtime calls for benchmarking. + #[cfg(feature = "runtime-benchmarks")] + type BenchmarkHelper: BenchmarkHelperTrait<>::RuntimeCall>; + } + + /// Storage mapping from transaction identifier to its configured rate limit. + #[pallet::storage] + #[pallet::getter(fn limits)] + pub type Limits, I: 'static = ()> = StorageMap< + _, + Blake2_128Concat, + TransactionIdentifier, + RateLimit<>::LimitScope, BlockNumberFor>, + OptionQuery, + >; + + /// Tracks when a transaction was last observed. + /// + /// The second key is `None` for global tracking and `Some(key)` for scoped usage tracking. + #[pallet::storage] + pub type LastSeen, I: 'static = ()> = StorageDoubleMap< + _, + Blake2_128Concat, + TransactionIdentifier, + Blake2_128Concat, + Option<>::UsageKey>, + BlockNumberFor, + OptionQuery, + >; + + /// Default block span applied when an extrinsic uses the default rate limit. + #[pallet::storage] + #[pallet::getter(fn default_limit)] + pub type DefaultLimit, I: 'static = ()> = + StorageValue<_, BlockNumberFor, ValueQuery>; + + /// Events emitted by the rate limiting pallet. + #[pallet::event] + #[pallet::generate_deposit(pub(super) fn deposit_event)] + pub enum Event, I: 'static = ()> { + /// A rate limit was set or updated. + RateLimitSet { + /// Identifier of the affected transaction. + transaction: TransactionIdentifier, + /// Limit scope to which the configuration applies, if any. + scope: Option<>::LimitScope>, + /// The rate limit policy applied to the transaction. + limit: RateLimitKind>, + /// Pallet name associated with the transaction. + pallet: Vec, + /// Extrinsic name associated with the transaction. + extrinsic: Vec, + }, + /// A rate limit was cleared. + RateLimitCleared { + /// Identifier of the affected transaction. + transaction: TransactionIdentifier, + /// Limit scope from which the configuration was cleared, if any. + scope: Option<>::LimitScope>, + /// Pallet name associated with the transaction. + pallet: Vec, + /// Extrinsic name associated with the transaction. + extrinsic: Vec, + }, + /// The default rate limit was set or updated. + DefaultRateLimitSet { + /// The new default limit expressed in blocks. + block_span: BlockNumberFor, + }, + } + + /// Errors that can occur while configuring rate limits. + #[pallet::error] + pub enum Error { + /// Failed to extract the pallet and extrinsic indices from the call. + InvalidRuntimeCall, + /// Attempted to remove a limit that is not present. + MissingRateLimit, + } + + #[pallet::genesis_config] + pub struct GenesisConfig, I: 'static = ()> { + pub default_limit: BlockNumberFor, + pub limits: Vec<( + TransactionIdentifier, + Option<>::LimitScope>, + RateLimitKind>, + )>, + } + + #[cfg(feature = "std")] + impl, I: 'static> Default for GenesisConfig { + fn default() -> Self { + Self { + default_limit: Zero::zero(), + limits: Vec::new(), + } + } + } + + #[pallet::genesis_build] + impl, I: 'static> BuildGenesisConfig for GenesisConfig { + fn build(&self) { + DefaultLimit::::put(self.default_limit); + + for (identifier, scope, kind) in &self.limits { + Limits::::mutate(identifier, |entry| match scope { + None => { + *entry = Some(RateLimit::global(*kind)); + } + Some(sc) => { + if let Some(config) = entry { + config.upsert_scope(sc.clone(), *kind); + } else { + *entry = Some(RateLimit::scoped_single(sc.clone(), *kind)); + } + } + }); + } + } + } + + #[pallet::pallet] + #[pallet::without_storage_info] + pub struct Pallet(PhantomData<(T, I)>); + + impl, I: 'static> Pallet { + /// Returns `true` when the given transaction identifier passes its configured rate limit + /// within the provided usage scope. + pub fn is_within_limit( + identifier: &TransactionIdentifier, + scope: &Option<>::LimitScope>, + usage_key: &Option<>::UsageKey>, + ) -> Result { + let Some(block_span) = Self::resolved_limit(identifier, scope) else { + return Ok(true); + }; + + let current = frame_system::Pallet::::block_number(); + + if let Some(last) = LastSeen::::get(identifier, usage_key) { + let delta = current.saturating_sub(last); + if delta < block_span { + return Ok(false); + } + } + + Ok(true) + } + + pub(crate) fn resolved_limit( + identifier: &TransactionIdentifier, + scope: &Option<>::LimitScope>, + ) -> Option> { + let config = Limits::::get(identifier)?; + let kind = config.kind_for(scope.as_ref())?; + Some(match *kind { + RateLimitKind::Default => DefaultLimit::::get(), + RateLimitKind::Exact(block_span) => block_span, + }) + } + + /// Returns the configured limit for the specified pallet/extrinsic names, if any. + pub fn limit_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + scope: Option<>::LimitScope>, + ) -> Option>> { + let identifier = Self::identifier_for_call_names(pallet_name, extrinsic_name)?; + Limits::::get(&identifier) + .and_then(|config| config.kind_for(scope.as_ref()).copied()) + } + + /// Returns the resolved block span for the specified pallet/extrinsic names, if any. + pub fn resolved_limit_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + scope: Option<>::LimitScope>, + ) -> Option> { + let identifier = Self::identifier_for_call_names(pallet_name, extrinsic_name)?; + Self::resolved_limit(&identifier, &scope) + } + + fn identifier_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + ) -> Option { + let modules = >::RuntimeCall::get_module_names(); + let pallet_pos = modules.iter().position(|name| *name == pallet_name)?; + let call_names = >::RuntimeCall::get_call_names(pallet_name); + let extrinsic_pos = call_names.iter().position(|name| *name == extrinsic_name)?; + let pallet_index = u8::try_from(pallet_pos).ok()?; + let extrinsic_index = u8::try_from(extrinsic_pos).ok()?; + Some(TransactionIdentifier::new(pallet_index, extrinsic_index)) + } + } + + #[pallet::call] + impl, I: 'static> Pallet { + /// Sets the rate limit configuration for the given call. + /// + /// The supplied `call` is only used to derive the pallet and extrinsic indices; **any + /// arguments embedded in the call are ignored**. The applicable scope is discovered via + /// [`Config::LimitScopeResolver`]. When a scope resolves, the configuration is stored + /// against that scope; otherwise the global entry is updated. + #[pallet::call_index(0)] + #[pallet::weight(T::DbWeight::get().reads_writes(1, 1))] + pub fn set_rate_limit( + origin: OriginFor, + call: Box<>::RuntimeCall>, + limit: RateLimitKind>, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + let identifier = TransactionIdentifier::from_call::(call.as_ref())?; + let scope = >::LimitScopeResolver::context(call.as_ref()); + let scope_for_event = scope.clone(); + + if let Some(ref sc) = scope { + Limits::::mutate(&identifier, |slot| match slot { + Some(config) => config.upsert_scope(sc.clone(), limit), + None => *slot = Some(RateLimit::scoped_single(sc.clone(), limit)), + }); + } else { + Limits::::insert(&identifier, RateLimit::global(limit)); + } + + let (pallet_name, extrinsic_name) = identifier.names::()?; + let pallet = Vec::from(pallet_name.as_bytes()); + let extrinsic = Vec::from(extrinsic_name.as_bytes()); + + Self::deposit_event(Event::RateLimitSet { + transaction: identifier, + scope: scope_for_event, + limit, + pallet, + extrinsic, + }); + Ok(()) + } + + /// Clears the rate limit for the given call, if present. + /// + /// The supplied `call` is only used to derive the pallet and extrinsic indices; **any + /// arguments embedded in the call are ignored**. The configuration scope is determined via + /// [`Config::LimitScopeResolver`]. When no scope resolves, the global entry is cleared. + #[pallet::call_index(1)] + #[pallet::weight(T::DbWeight::get().reads_writes(1, 1))] + pub fn clear_rate_limit( + origin: OriginFor, + call: Box<>::RuntimeCall>, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + let identifier = TransactionIdentifier::from_call::(call.as_ref())?; + let scope = >::LimitScopeResolver::context(call.as_ref()); + + let (pallet_name, extrinsic_name) = identifier.names::()?; + let pallet = Vec::from(pallet_name.as_bytes()); + let extrinsic = Vec::from(extrinsic_name.as_bytes()); + + let mut removed = false; + Limits::::mutate_exists(&identifier, |maybe_config| { + if let Some(config) = maybe_config { + match (&scope, config) { + (None, _) => { + removed = true; + *maybe_config = None; + } + (Some(sc), RateLimit::Scoped(map)) => { + if map.remove(sc).is_some() { + removed = true; + if map.is_empty() { + *maybe_config = None; + } + } + } + (Some(_), RateLimit::Global(_)) => {} + } + } + }); + + ensure!(removed, Error::::MissingRateLimit); + + Self::deposit_event(Event::RateLimitCleared { + transaction: identifier, + scope, + pallet, + extrinsic, + }); + + Ok(()) + } + + /// Sets the default rate limit in blocks applied to calls configured to use it. + #[pallet::call_index(2)] + #[pallet::weight(T::DbWeight::get().writes(1))] + pub fn set_default_rate_limit( + origin: OriginFor, + block_span: BlockNumberFor, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + DefaultLimit::::put(block_span); + + Self::deposit_event(Event::DefaultRateLimitSet { block_span }); + + Ok(()) + } + } +} diff --git a/pallets/rate-limiting/src/mock.rs b/pallets/rate-limiting/src/mock.rs new file mode 100644 index 000000000..aec00b45b --- /dev/null +++ b/pallets/rate-limiting/src/mock.rs @@ -0,0 +1,127 @@ +#![allow(dead_code)] + +use core::convert::TryInto; + +use frame_support::{ + derive_impl, + sp_runtime::{ + BuildStorage, + traits::{BlakeTwo256, IdentityLookup}, + }, + traits::{ConstU16, ConstU32, ConstU64, Everything}, +}; +use frame_system::EnsureRoot; +use sp_core::H256; +use sp_io::TestExternalities; +use sp_std::vec::Vec; + +use crate as pallet_rate_limiting; +use crate::TransactionIdentifier; + +pub type UncheckedExtrinsic = frame_system::mocking::MockUncheckedExtrinsic; +pub type Block = frame_system::mocking::MockBlock; + +frame_support::construct_runtime!( + pub enum Test { + System: frame_system, + RateLimiting: pallet_rate_limiting, + } +); + +#[derive_impl(frame_system::config_preludes::TestDefaultConfig)] +impl frame_system::Config for Test { + type BaseCallFilter = Everything; + type BlockWeights = (); + type BlockLength = (); + type DbWeight = (); + type RuntimeOrigin = RuntimeOrigin; + type RuntimeCall = RuntimeCall; + type Nonce = u64; + type Hash = H256; + type Hashing = BlakeTwo256; + type AccountId = u64; + type Lookup = IdentityLookup; + type RuntimeEvent = RuntimeEvent; + type BlockHashCount = ConstU64<250>; + type Version = (); + type PalletInfo = PalletInfo; + type AccountData = (); + type OnNewAccount = (); + type OnKilledAccount = (); + type SystemWeightInfo = (); + type SS58Prefix = ConstU16<42>; + type OnSetCode = (); + type MaxConsumers = ConstU32<16>; + type Block = Block; +} + +pub type LimitScope = u16; +pub type UsageKey = u16; + +pub struct TestScopeResolver; +pub struct TestUsageResolver; + +impl pallet_rate_limiting::RateLimitContextResolver for TestScopeResolver { + fn context(call: &RuntimeCall) -> Option { + match call { + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span }) => { + (*block_span).try_into().ok() + } + RuntimeCall::RateLimiting(_) => Some(1), + _ => None, + } + } +} + +impl pallet_rate_limiting::RateLimitContextResolver for TestUsageResolver { + fn context(call: &RuntimeCall) -> Option { + match call { + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span }) => { + (*block_span).try_into().ok() + } + RuntimeCall::RateLimiting(_) => Some(1), + _ => None, + } + } +} + +impl pallet_rate_limiting::Config for Test { + type RuntimeCall = RuntimeCall; + type LimitScope = LimitScope; + type LimitScopeResolver = TestScopeResolver; + type UsageKey = UsageKey; + type UsageResolver = TestUsageResolver; + type AdminOrigin = EnsureRoot; + #[cfg(feature = "runtime-benchmarks")] + type BenchmarkHelper = BenchHelper; +} + +#[cfg(feature = "runtime-benchmarks")] +pub struct BenchHelper; + +#[cfg(feature = "runtime-benchmarks")] +impl crate::BenchmarkHelper for BenchHelper { + fn sample_call() -> RuntimeCall { + RuntimeCall::System(frame_system::Call::remark { remark: Vec::new() }) + } +} + +pub type RateLimitingCall = crate::Call; + +pub fn new_test_ext() -> TestExternalities { + let storage = frame_system::GenesisConfig::::default() + .build_storage() + .expect("genesis build succeeds"); + + let mut ext = TestExternalities::new(storage); + ext.execute_with(|| System::set_block_number(1)); + ext +} + +pub(crate) fn identifier_for(call: &RuntimeCall) -> TransactionIdentifier { + TransactionIdentifier::from_call::(call).expect("identifier for call") +} + +pub(crate) fn pop_last_event() -> RuntimeEvent { + System::events().pop().expect("event expected").event +} diff --git a/pallets/rate-limiting/src/tests.rs b/pallets/rate-limiting/src/tests.rs new file mode 100644 index 000000000..f02c2c52b --- /dev/null +++ b/pallets/rate-limiting/src/tests.rs @@ -0,0 +1,414 @@ +use frame_support::{assert_noop, assert_ok, error::BadOrigin}; +use sp_std::{collections::btree_map::BTreeMap, vec::Vec}; + +use crate::{DefaultLimit, LastSeen, Limits, RateLimit, RateLimitKind, mock::*, pallet::Error}; + +#[test] +fn limit_for_call_names_returns_none_if_not_set() { + new_test_ext().execute_with(|| { + assert!( + RateLimiting::limit_for_call_names("RateLimiting", "set_default_rate_limit", None) + .is_none() + ); + }); +} + +#[test] +fn limit_for_call_names_returns_stored_limit() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Exact(7))); + + let fetched = + RateLimiting::limit_for_call_names("RateLimiting", "set_default_rate_limit", None) + .expect("limit should exist"); + assert_eq!(fetched, RateLimitKind::Exact(7)); + }); +} + +#[test] +fn limit_for_call_names_prefers_scope_specific_limit() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + Limits::::insert( + identifier, + RateLimit::scoped_single(5u16, RateLimitKind::Exact(8)), + ); + + let fetched = + RateLimiting::limit_for_call_names("RateLimiting", "set_default_rate_limit", Some(5)) + .expect("limit should exist"); + assert_eq!(fetched, RateLimitKind::Exact(8)); + + assert!( + RateLimiting::limit_for_call_names("RateLimiting", "set_default_rate_limit", Some(1)) + .is_none() + ); + }); +} + +#[test] +fn resolved_limit_for_call_names_resolves_default_value() { + new_test_ext().execute_with(|| { + DefaultLimit::::put(3); + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Default)); + + let resolved = RateLimiting::resolved_limit_for_call_names( + "RateLimiting", + "set_default_rate_limit", + None, + ) + .expect("resolved limit"); + assert_eq!(resolved, 3); + }); +} + +#[test] +fn resolved_limit_for_call_names_prefers_scope_specific_value() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + let mut map = BTreeMap::new(); + map.insert(6u16, RateLimitKind::Exact(9)); + map.insert(2u16, RateLimitKind::Exact(4)); + Limits::::insert(identifier, RateLimit::Scoped(map)); + + let resolved = RateLimiting::resolved_limit_for_call_names( + "RateLimiting", + "set_default_rate_limit", + Some(6), + ) + .expect("resolved limit"); + assert_eq!(resolved, 9); + + assert!( + RateLimiting::resolved_limit_for_call_names( + "RateLimiting", + "set_default_rate_limit", + Some(1), + ) + .is_none() + ); + }); +} + +#[test] +fn resolved_limit_for_call_names_returns_none_when_unset() { + new_test_ext().execute_with(|| { + assert!( + RateLimiting::resolved_limit_for_call_names( + "RateLimiting", + "set_default_rate_limit", + None, + ) + .is_none() + ); + }); +} + +#[test] +fn is_within_limit_is_true_when_no_limit() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + + let result = RateLimiting::is_within_limit(&identifier, &None, &None); + assert_eq!(result.expect("no error expected"), true); + }); +} + +#[test] +fn is_within_limit_false_when_rate_limited() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + Limits::::insert( + identifier, + RateLimit::scoped_single(1 as LimitScope, RateLimitKind::Exact(5)), + ); + LastSeen::::insert(identifier, Some(1 as UsageKey), 9); + + System::set_block_number(13); + + let within = RateLimiting::is_within_limit( + &identifier, + &Some(1 as LimitScope), + &Some(1 as UsageKey), + ) + .expect("call succeeds"); + assert!(!within); + }); +} + +#[test] +fn is_within_limit_true_after_required_span() { + new_test_ext().execute_with(|| { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&call); + Limits::::insert( + identifier, + RateLimit::scoped_single(2 as LimitScope, RateLimitKind::Exact(5)), + ); + LastSeen::::insert(identifier, Some(2 as UsageKey), 10); + + System::set_block_number(20); + + let within = RateLimiting::is_within_limit( + &identifier, + &Some(2 as LimitScope), + &Some(2 as UsageKey), + ) + .expect("call succeeds"); + assert!(within); + }); +} + +#[test] +fn set_rate_limit_updates_storage_and_emits_event() { + new_test_ext().execute_with(|| { + System::reset_events(); + + let target_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let limit = RateLimitKind::Exact(9); + + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + Box::new(target_call.clone()), + limit, + )); + + let identifier = identifier_for(&target_call); + assert_eq!( + Limits::::get(identifier), + Some(RateLimit::scoped_single(0, limit)) + ); + + match pop_last_event() { + RuntimeEvent::RateLimiting(crate::pallet::Event::RateLimitSet { + transaction, + scope, + limit: emitted_limit, + pallet, + extrinsic, + }) => { + assert_eq!(transaction, identifier); + assert_eq!(scope, Some(0)); + assert_eq!(emitted_limit, limit); + assert_eq!(pallet, b"RateLimiting".to_vec()); + assert_eq!(extrinsic, b"set_default_rate_limit".to_vec()); + } + other => panic!("unexpected event: {:?}", other), + } + }); +} + +#[test] +fn set_rate_limit_stores_global_when_scope_absent() { + new_test_ext().execute_with(|| { + System::reset_events(); + + let target_call = + RuntimeCall::System(frame_system::Call::::remark { remark: Vec::new() }); + let limit = RateLimitKind::Exact(11); + + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + Box::new(target_call.clone()), + limit, + )); + + let identifier = identifier_for(&target_call); + assert_eq!( + Limits::::get(identifier), + Some(RateLimit::global(limit)) + ); + + match pop_last_event() { + RuntimeEvent::RateLimiting(crate::pallet::Event::RateLimitSet { + transaction, + scope, + limit: emitted_limit, + pallet, + extrinsic, + }) => { + assert_eq!(transaction, identifier); + assert_eq!(scope, None); + assert_eq!(emitted_limit, limit); + assert_eq!(pallet, b"System".to_vec()); + assert_eq!(extrinsic, b"remark".to_vec()); + } + other => panic!("unexpected event: {:?}", other), + } + }); +} + +#[test] +fn set_rate_limit_requires_root() { + new_test_ext().execute_with(|| { + let target_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + + assert_noop!( + RateLimiting::set_rate_limit( + RuntimeOrigin::signed(1), + Box::new(target_call), + RateLimitKind::Exact(1), + ), + BadOrigin + ); + }); +} + +#[test] +fn set_rate_limit_accepts_default_variant() { + new_test_ext().execute_with(|| { + let target_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + Box::new(target_call.clone()), + RateLimitKind::Default, + )); + + let identifier = identifier_for(&target_call); + assert_eq!( + Limits::::get(identifier), + Some(RateLimit::scoped_single(0, RateLimitKind::Default)) + ); + }); +} + +#[test] +fn clear_rate_limit_removes_entry_and_emits_event() { + new_test_ext().execute_with(|| { + System::reset_events(); + + let target_call = + RuntimeCall::System(frame_system::Call::::remark { remark: Vec::new() }); + let identifier = identifier_for(&target_call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Exact(4))); + + assert_ok!(RateLimiting::clear_rate_limit( + RuntimeOrigin::root(), + Box::new(target_call.clone()), + )); + + assert!(Limits::::get(identifier).is_none()); + + match pop_last_event() { + RuntimeEvent::RateLimiting(crate::pallet::Event::RateLimitCleared { + transaction, + scope, + pallet, + extrinsic, + }) => { + assert_eq!(transaction, identifier); + assert_eq!(scope, None); + assert_eq!(pallet, b"System".to_vec()); + assert_eq!(extrinsic, b"remark".to_vec()); + } + other => panic!("unexpected event: {:?}", other), + } + }); +} + +#[test] +fn clear_rate_limit_fails_when_missing() { + new_test_ext().execute_with(|| { + let target_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + + assert_noop!( + RateLimiting::clear_rate_limit(RuntimeOrigin::root(), Box::new(target_call)), + Error::::MissingRateLimit + ); + }); +} + +#[test] +fn clear_rate_limit_removes_only_selected_scope() { + new_test_ext().execute_with(|| { + System::reset_events(); + + let base_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = identifier_for(&base_call); + let mut map = BTreeMap::new(); + map.insert(9u16, RateLimitKind::Exact(7)); + map.insert(10u16, RateLimitKind::Exact(5)); + Limits::::insert(identifier, RateLimit::Scoped(map)); + + let scoped_call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 9 }); + + assert_ok!(RateLimiting::clear_rate_limit( + RuntimeOrigin::root(), + Box::new(scoped_call.clone()), + )); + + let config = Limits::::get(identifier).expect("config remains"); + assert!(config.kind_for(Some(&9u16)).is_none()); + assert_eq!( + config.kind_for(Some(&10u16)).copied(), + Some(RateLimitKind::Exact(5)) + ); + + match pop_last_event() { + RuntimeEvent::RateLimiting(crate::pallet::Event::RateLimitCleared { + transaction, + scope, + .. + }) => { + assert_eq!(transaction, identifier); + assert_eq!(scope, Some(9)); + } + other => panic!("unexpected event: {:?}", other), + } + }); +} + +#[test] +fn set_default_rate_limit_updates_storage_and_emits_event() { + new_test_ext().execute_with(|| { + System::reset_events(); + + assert_ok!(RateLimiting::set_default_rate_limit( + RuntimeOrigin::root(), + 42 + )); + + assert_eq!(DefaultLimit::::get(), 42); + + match pop_last_event() { + RuntimeEvent::RateLimiting(crate::pallet::Event::DefaultRateLimitSet { + block_span, + }) => { + assert_eq!(block_span, 42); + } + other => panic!("unexpected event: {:?}", other), + } + }); +} + +#[test] +fn set_default_rate_limit_requires_root() { + new_test_ext().execute_with(|| { + assert_noop!( + RateLimiting::set_default_rate_limit(RuntimeOrigin::signed(1), 5), + BadOrigin + ); + }); +} diff --git a/pallets/rate-limiting/src/tx_extension.rs b/pallets/rate-limiting/src/tx_extension.rs new file mode 100644 index 000000000..5276f1f39 --- /dev/null +++ b/pallets/rate-limiting/src/tx_extension.rs @@ -0,0 +1,342 @@ +use codec::{Decode, DecodeWithMemTracking, Encode}; +use frame_support::{ + dispatch::{DispatchInfo, DispatchResult, PostDispatchInfo}, + pallet_prelude::Weight, + sp_runtime::{ + traits::{ + DispatchInfoOf, DispatchOriginOf, Dispatchable, Implication, TransactionExtension, + ValidateResult, Zero, + }, + transaction_validity::{ + InvalidTransaction, TransactionSource, TransactionValidityError, ValidTransaction, + }, + }, +}; +use scale_info::TypeInfo; +use sp_std::{marker::PhantomData, result::Result}; + +use crate::{ + Config, LastSeen, Pallet, + types::{RateLimitContextResolver, TransactionIdentifier}, +}; + +/// Identifier returned in the transaction metadata for the rate limiting extension. +const IDENTIFIER: &str = "RateLimitTransactionExtension"; + +/// Custom error code used to signal a rate limit violation. +const RATE_LIMIT_DENIED: u8 = 1; + +/// Transaction extension that enforces pallet rate limiting rules. +#[derive(Default, Encode, Decode, DecodeWithMemTracking, TypeInfo)] +pub struct RateLimitTransactionExtension(PhantomData<(T, I)>) +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo; + +impl Clone for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl PartialEq for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ +} + +impl core::fmt::Debug for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(IDENTIFIER) + } +} + +impl TransactionExtension<>::RuntimeCall> + for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo + Send + Sync, + >::RuntimeCall: Dispatchable, +{ + const IDENTIFIER: &'static str = IDENTIFIER; + + type Implicit = (); + type Val = Option<(TransactionIdentifier, Option<>::UsageKey>)>; + type Pre = Option<(TransactionIdentifier, Option<>::UsageKey>)>; + + fn weight(&self, _call: &>::RuntimeCall) -> Weight { + Weight::zero() + } + + fn validate( + &self, + origin: DispatchOriginOf<>::RuntimeCall>, + call: &>::RuntimeCall, + _info: &DispatchInfoOf<>::RuntimeCall>, + _len: usize, + _self_implicit: Self::Implicit, + _inherited_implication: &impl Implication, + _source: TransactionSource, + ) -> ValidateResult>::RuntimeCall> { + let identifier = match TransactionIdentifier::from_call::(call) { + Ok(identifier) => identifier, + Err(_) => return Err(TransactionValidityError::Invalid(InvalidTransaction::Call)), + }; + + let scope = >::LimitScopeResolver::context(call); + let usage = >::UsageResolver::context(call); + + let Some(block_span) = Pallet::::resolved_limit(&identifier, &scope) else { + return Ok((ValidTransaction::default(), None, origin)); + }; + + if block_span.is_zero() { + return Ok((ValidTransaction::default(), None, origin)); + } + + let within_limit = Pallet::::is_within_limit(&identifier, &scope, &usage) + .map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Call))?; + + if !within_limit { + return Err(TransactionValidityError::Invalid( + InvalidTransaction::Custom(RATE_LIMIT_DENIED), + )); + } + + Ok(( + ValidTransaction::default(), + Some((identifier, usage)), + origin, + )) + } + + fn prepare( + self, + val: Self::Val, + _origin: &DispatchOriginOf<>::RuntimeCall>, + _call: &>::RuntimeCall, + _info: &DispatchInfoOf<>::RuntimeCall>, + _len: usize, + ) -> Result { + Ok(val) + } + + fn post_dispatch( + pre: Self::Pre, + _info: &DispatchInfoOf<>::RuntimeCall>, + _post_info: &mut PostDispatchInfo, + _len: usize, + result: &DispatchResult, + ) -> Result<(), TransactionValidityError> { + if result.is_ok() { + if let Some((identifier, usage)) = pre { + let block_number = frame_system::Pallet::::block_number(); + LastSeen::::insert(&identifier, usage, block_number); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use codec::Encode; + use frame_support::dispatch::{GetDispatchInfo, PostDispatchInfo}; + use sp_runtime::{ + traits::{TransactionExtension, TxBaseImplication}, + transaction_validity::{InvalidTransaction, TransactionSource, TransactionValidityError}, + }; + + use crate::{ + LastSeen, Limits, + types::{RateLimit, RateLimitKind, TransactionIdentifier}, + }; + + use super::*; + use crate::mock::*; + + fn remark_call() -> RuntimeCall { + RuntimeCall::System(frame_system::Call::::remark { remark: Vec::new() }) + } + + fn new_tx_extension() -> RateLimitTransactionExtension { + RateLimitTransactionExtension(Default::default()) + } + + fn validate_with_tx_extension( + extension: &RateLimitTransactionExtension, + call: &RuntimeCall, + ) -> Result< + ( + sp_runtime::transaction_validity::ValidTransaction, + Option<(TransactionIdentifier, Option)>, + RuntimeOrigin, + ), + TransactionValidityError, + > { + let info = call.get_dispatch_info(); + let len = call.encode().len(); + extension.validate( + RuntimeOrigin::signed(42), + call, + &info, + len, + (), + &TxBaseImplication(()), + TransactionSource::External, + ) + } + + #[test] + fn tx_extension_allows_calls_without_limit() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + + let (_valid, val, _origin) = + validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_none()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + let identifier = identifier_for(&call); + assert_eq!( + LastSeen::::get(identifier, None::), + None + ); + }); + } + + #[test] + fn tx_extension_records_last_seen_for_successful_call() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Exact(5))); + + System::set_block_number(10); + + let (_valid, val, _) = validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_some()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + assert_eq!( + LastSeen::::get(identifier, None::), + Some(10) + ); + }); + } + + #[test] + fn tx_extension_rejects_when_call_occurs_too_soon() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Exact(5))); + LastSeen::::insert(identifier, None::, 20); + + System::set_block_number(22); + + let err = + validate_with_tx_extension(&extension, &call).expect_err("should be rate limited"); + match err { + TransactionValidityError::Invalid(InvalidTransaction::Custom(code)) => { + assert_eq!(code, 1); + } + other => panic!("unexpected error: {:?}", other), + } + }); + } + + #[test] + fn tx_extension_skips_last_seen_when_span_zero() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + Limits::::insert(identifier, RateLimit::global(RateLimitKind::Exact(0))); + + System::set_block_number(30); + + let (_valid, val, _) = validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_none()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + assert_eq!( + LastSeen::::get(identifier, None::), + None + ); + }); + } +} diff --git a/pallets/rate-limiting/src/types.rs b/pallets/rate-limiting/src/types.rs new file mode 100644 index 000000000..1daf2915b --- /dev/null +++ b/pallets/rate-limiting/src/types.rs @@ -0,0 +1,214 @@ +use codec::{Decode, DecodeWithMemTracking, Encode, MaxEncodedLen}; +use frame_support::{pallet_prelude::DispatchError, traits::GetCallMetadata}; +use scale_info::TypeInfo; +use sp_std::collections::btree_map::BTreeMap; + +/// Resolves the optional identifier within which a rate limit applies. +pub trait RateLimitContextResolver { + /// Returns `Some(context)` when the limit should be applied per-context, or `None` for global + /// limits. + fn context(call: &Call) -> Option; +} + +/// Identifies a runtime call by pallet and extrinsic indices. +#[cfg_attr(feature = "std", derive(serde::Deserialize, serde::Serialize))] +#[derive( + Clone, + Copy, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub struct TransactionIdentifier { + /// Pallet variant index. + pub pallet_index: u8, + /// Call variant index within the pallet. + pub extrinsic_index: u8, +} + +impl TransactionIdentifier { + /// Builds a new identifier from pallet/extrinsic indices. + pub const fn new(pallet_index: u8, extrinsic_index: u8) -> Self { + Self { + pallet_index, + extrinsic_index, + } + } + + /// Returns the pallet and extrinsic names associated with this identifier. + pub fn names(&self) -> Result<(&'static str, &'static str), DispatchError> + where + T: crate::pallet::Config, + I: 'static, + >::RuntimeCall: GetCallMetadata, + { + let modules = >::RuntimeCall::get_module_names(); + let pallet_name = modules + .get(self.pallet_index as usize) + .copied() + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + let call_names = >::RuntimeCall::get_call_names(pallet_name); + let extrinsic_name = call_names + .get(self.extrinsic_index as usize) + .copied() + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + Ok((pallet_name, extrinsic_name)) + } + + /// Builds an identifier from a runtime call by extracting pallet/extrinsic indices. + pub fn from_call( + call: &>::RuntimeCall, + ) -> Result + where + T: crate::pallet::Config, + I: 'static, + { + call.using_encoded(|encoded| { + let pallet_index = *encoded + .get(0) + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + let extrinsic_index = *encoded + .get(1) + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + Ok(TransactionIdentifier::new(pallet_index, extrinsic_index)) + }) + } +} + +/// Policy describing the block span enforced by a rate limit. +#[cfg_attr(feature = "std", derive(serde::Deserialize, serde::Serialize))] +#[derive( + Clone, + Copy, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub enum RateLimitKind { + /// Use the pallet-level default rate limit. + Default, + /// Apply an exact rate limit measured in blocks. + Exact(BlockNumber), +} + +/// Stored rate limit configuration for a transaction identifier. +/// +/// The configuration is mutually exclusive: either the call is globally limited or it stores a set +/// of per-scope spans. +#[cfg_attr(feature = "std", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr( + feature = "std", + serde( + bound = "Scope: Ord + serde::Serialize + serde::de::DeserializeOwned, BlockNumber: serde::Serialize + serde::de::DeserializeOwned" + ) +)] +#[derive(Clone, PartialEq, Eq, Encode, Decode, DecodeWithMemTracking, TypeInfo, Debug)] +pub enum RateLimit { + /// Global span applied to every invocation. + Global(RateLimitKind), + /// Per-scope spans keyed by `Scope`. + Scoped(BTreeMap>), +} + +impl RateLimit +where + Scope: Ord, +{ + /// Convenience helper to build a global configuration. + pub fn global(kind: RateLimitKind) -> Self { + Self::Global(kind) + } + + /// Convenience helper to build a scoped configuration containing a single entry. + pub fn scoped_single(scope: Scope, kind: RateLimitKind) -> Self { + let mut map = BTreeMap::new(); + map.insert(scope, kind); + Self::Scoped(map) + } + + /// Returns the span configured for the provided scope, if any. + pub fn kind_for(&self, scope: Option<&Scope>) -> Option<&RateLimitKind> { + match self { + RateLimit::Global(kind) => Some(kind), + RateLimit::Scoped(map) => scope.and_then(|key| map.get(key)), + } + } + + /// Inserts or updates a scoped entry, converting from a global configuration if needed. + pub fn upsert_scope(&mut self, scope: Scope, kind: RateLimitKind) { + match self { + RateLimit::Global(_) => { + let mut map = BTreeMap::new(); + map.insert(scope, kind); + *self = RateLimit::Scoped(map); + } + RateLimit::Scoped(map) => { + map.insert(scope, kind); + } + } + } + + /// Removes a scoped entry, returning whether one existed. + pub fn remove_scope(&mut self, scope: &Scope) -> bool { + match self { + RateLimit::Global(_) => false, + RateLimit::Scoped(map) => map.remove(scope).is_some(), + } + } + + /// Returns true when the scoped configuration contains no entries. + pub fn is_scoped_empty(&self) -> bool { + matches!(self, RateLimit::Scoped(map) if map.is_empty()) + } +} + +#[cfg(test)] +mod tests { + use sp_runtime::DispatchError; + + use super::*; + use crate::{mock::*, pallet::Error}; + + #[test] + fn transaction_identifier_from_call_matches_expected_indices() { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + + let identifier = TransactionIdentifier::from_call::(&call).expect("identifier"); + + // System is the first pallet in the mock runtime, RateLimiting is second. + assert_eq!(identifier.pallet_index, 1); + // set_default_rate_limit has call_index 2. + assert_eq!(identifier.extrinsic_index, 2); + } + + #[test] + fn transaction_identifier_names_matches_call_metadata() { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = TransactionIdentifier::from_call::(&call).expect("identifier"); + + let (pallet, extrinsic) = identifier.names::().expect("call metadata"); + assert_eq!(pallet, "RateLimiting"); + assert_eq!(extrinsic, "set_default_rate_limit"); + } + + #[test] + fn transaction_identifier_names_error_for_unknown_indices() { + let identifier = TransactionIdentifier::new(99, 0); + + let err = identifier.names::().expect_err("should fail"); + let expected: DispatchError = Error::::InvalidRuntimeCall.into(); + assert_eq!(err, expected); + } +}