Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
744 changes: 716 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ futures = "0.3.5"
http = "1.0"
http-body-util = "0.1"
hyper = { version = "1.0", features = ["full"] }
hyper-server = "0.6.0"
hyper-util = { version = "0.1", features = ["tokio", "server", "http1", "http2"] }
tokio = { version = "1.5.0", features = ["rt", "macros"] }
deadpool = "0.10.0"
Expand All @@ -36,9 +37,22 @@ assert-json-diff = "2.0.1"
base64 = "0.22"
url = "2.2"

rcgen = { version = "0.13.2", optional = true }
rustls-pki-types = { version = "1.11.0", optional = true }

[dev-dependencies]
async-std = { version = "1.13.0", features = ["attributes", "tokio1"] }
reqwest = { version = "0.12.7", features = ["json"] }
tokio = { version = "1.5.0", features = ["macros", "rt-multi-thread"] }
actix-rt = "2.2.0"
serde = { version = "1", features = ["derive"] }
criterion = "0.5.1"

[features]
default = []
# Enable HTTPS mock server support.
tls = ["hyper-server/tls-rustls", "rcgen", "rustls-pki-types", "reqwest/rustls-tls"]

[[bench]]
name = "tls_certs"
harness = false
34 changes: 34 additions & 0 deletions benches/tls_certs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#[cfg(feature = "tls")]
use criterion::{criterion_group, criterion_main, Criterion};

#[cfg(feature = "tls")]
use wiremock::tls_certs::MockTlsCertificates;

#[cfg(feature = "tls")]
// On the good old M1 processor it takes ~77 µs
pub fn tls_mock_tls_certificates_new(c: &mut Criterion) {
c.bench_function("MockTlsCertificates::new", |b| {
b.iter(|| MockTlsCertificates::random())
});
}

#[cfg(feature = "tls")]
pub fn tls_mock_tls_certificates_client(c: &mut Criterion) {
let mock_tls_certificates = MockTlsCertificates::random();
c.bench_function("MockTlsCertificates::gen_client", |b| {
b.iter(|| mock_tls_certificates.gen_client("[email protected]"))
});
}

#[cfg(feature = "tls")]
criterion_group!(
benches,
tls_mock_tls_certificates_new,
tls_mock_tls_certificates_client
);

#[cfg(feature = "tls")]
criterion_main!(benches);

#[cfg(not(feature = "tls"))]
fn main() {}
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@
//! The pool is designed to be invisible: it makes your life easier and your tests faster. If you
//! end up having to worry about it, it's a bug: open an issue!
//!
//! ## HTTPS
//!
//! You may start a HTTPS server with `MockServer::start_https(MockServerTlsConfig)` method.
//! The `MockServerTlsConfig` can be created from pregenerated TLS certificates, or you can
//! generate self-signed server and client certificates with a `MockTlsCertificates` instance.
//!
//! HTTPS servers are not pooled yet.
//!
//! HTTPS functionality is gated by the `tls` feature.
//!
//! ## Prior art
//!
//! [`mockito`] and [`httpmock`] provide HTTP mocking for Rust.
Expand Down Expand Up @@ -158,6 +168,8 @@ mod verification;
pub type ErrorResponse = Box<dyn std::error::Error + Send + Sync + 'static>;

pub use mock::{Match, Mock, MockBuilder, Times};
#[cfg(feature = "tls")]
pub use mock_server::tls_certs;
pub use mock_server::{MockGuard, MockServer, MockServerBuilder};
pub use request::Request;
pub use respond::Respond;
Expand Down
38 changes: 32 additions & 6 deletions src/mock_server/bare_server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use crate::mock_server::hyper::run_server;
use crate::mock_server::hyper::{run_server, HyperRequestHandler};
use crate::mock_set::MockId;
use crate::mock_set::MountedMockSet;
use crate::request::BodyPrintLimit;
use crate::{mock::Mock, verification::VerificationOutcome, ErrorResponse, Request};
use http_body_util::Full;
use hyper::body::Bytes;
use hyper_server::accept::Accept;
use std::fmt::{Debug, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::pin::pin;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::Notify;
use tokio::sync::RwLock;

Expand All @@ -22,6 +24,7 @@ use tokio::sync::RwLock;
pub(crate) struct BareMockServer {
state: Arc<RwLock<MockServerState>>,
server_address: SocketAddr,
proto: &'static str,
// When `_shutdown_trigger` gets dropped the listening server terminates gracefully.
_shutdown_trigger: tokio::sync::watch::Sender<()>,
}
Expand Down Expand Up @@ -52,11 +55,28 @@ impl MockServerState {
impl BareMockServer {
/// Start a new instance of a `BareMockServer` listening on the specified
/// [`TcpListener`].
pub(super) async fn start(
pub(super) async fn start<A>(
listener: TcpListener,
request_recording: RequestRecording,
body_print_limit: BodyPrintLimit,
) -> Self {
proto: &'static str,
acceptor: A,
) -> Self
where
A: Accept<tokio::net::TcpStream, HyperRequestHandler> + Send + Clone + 'static,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Future: Send,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Stream:
Unpin + Send + AsyncWrite + AsyncRead + 'static,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service:
hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<Full<Bytes>>>
+ Send,
<<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service as hyper::service::Service<
http::Request<hyper::body::Incoming>,
>>::Error: Send + Sync + Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
<<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service as hyper::service::Service<
http::Request<hyper::body::Incoming>,
>>::Future: Send + 'static,
{
let (shutdown_trigger, shutdown_receiver) = tokio::sync::watch::channel(());
let received_requests = match request_recording {
RequestRecording::Enabled => Some(Vec::new()),
Expand All @@ -73,7 +93,7 @@ impl BareMockServer {

let server_state = state.clone();
std::thread::spawn(move || {
let server_future = run_server(listener, server_state, shutdown_receiver);
let server_future = run_server(listener, server_state, shutdown_receiver, acceptor);

let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand All @@ -94,6 +114,7 @@ impl BareMockServer {
Self {
state,
server_address,
proto,
_shutdown_trigger: shutdown_trigger,
}
}
Expand Down Expand Up @@ -146,7 +167,7 @@ impl BareMockServer {
/// Use this method to compose uris when interacting with this instance of `BareMockServer` via
/// an HTTP client.
pub(crate) fn uri(&self) -> String {
format!("http://{}", self.server_address)
format!("{}://{}", self.proto, self.server_address)
}

/// Return the socket address of this running instance of `BareMockServer`, e.g. `127.0.0.1:4372`.
Expand All @@ -173,7 +194,12 @@ impl BareMockServer {

impl Debug for BareMockServer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "BareMockServer {{ address: {} }}", self.address())
write!(
f,
"BareMockServer {{ proto: {}, address: {} }}",
self.proto,
self.address()
)
}
}

Expand Down
63 changes: 59 additions & 4 deletions src/mock_server/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ use crate::mock_server::bare_server::{BareMockServer, RequestRecording};
use crate::mock_server::exposed_server::InnerServer;
use crate::request::{BodyPrintLimit, BODY_PRINT_LIMIT};
use crate::MockServer;
#[cfg(feature = "tls")]
use hyper_server::tls_rustls::RustlsConfig;
use std::env;
use std::net::TcpListener;

/// A builder providing a fluent API to assemble a [`MockServer`] step-by-step.
#[cfg(feature = "tls")]
use super::tls_certs::MockServerTlsConfig;

/// A builder providing a fluent API to assemble a [`MockServer`] step-by-step.
/// Use [`MockServer::builder`] to get started.
pub struct MockServerBuilder {
listener: Option<TcpListener>,
Expand Down Expand Up @@ -100,7 +105,36 @@ impl MockServerBuilder {
}

/// Finalise the builder to get an instance of a [`BareMockServer`].
pub(super) async fn build_bare(self) -> BareMockServer {
pub(super) async fn build_bare_http(self) -> BareMockServer {
use hyper_server::accept::DefaultAcceptor;

let listener = if let Some(listener) = self.listener {
listener
} else {
TcpListener::bind("127.0.0.1:0").expect("Failed to bind an OS port for a mock server.")
};
let recording = if self.record_incoming_requests {
RequestRecording::Enabled
} else {
RequestRecording::Disabled
};
BareMockServer::start(
listener,
recording,
self.body_print_limit,
"http",
DefaultAcceptor::new(),
)
.await
}

/// Finalise the builder to get an HTTPS instance of a [`BareMockServer`].
///
/// Panics if DER data the `certs` is invalid.
#[cfg(feature = "tls")]
pub(super) async fn build_bare_https(self, certs: MockServerTlsConfig) -> BareMockServer {
use hyper_server::tls_rustls::RustlsAcceptor;

let listener = if let Some(listener) = self.listener {
listener
} else {
Expand All @@ -111,11 +145,32 @@ impl MockServerBuilder {
} else {
RequestRecording::Disabled
};
BareMockServer::start(listener, recording, self.body_print_limit).await

let rustls_config = RustlsConfig::from_der(
vec![certs.server_cert_der, certs.root_cert_der],
certs.server_keypair_der,
)
.await
.expect("Failed to parse TLS configuration from DER data");

BareMockServer::start(
listener,
recording,
self.body_print_limit,
"https",
RustlsAcceptor::new(rustls_config),
)
.await
}

/// Finalise the builder and launch the [`MockServer`] instance!
pub async fn start(self) -> MockServer {
MockServer::new(InnerServer::Bare(self.build_bare().await))
MockServer::new(InnerServer::Bare(self.build_bare_http().await))
}

/// Finalise the builder and launch the HTTPS [`MockServer`] instance!
#[cfg(feature = "tls")]
pub async fn start_https(self, tls_conf: MockServerTlsConfig) -> MockServer {
MockServer::new(InnerServer::Bare(self.build_bare_https(tls_conf).await))
}
}
76 changes: 60 additions & 16 deletions src/mock_server/hyper.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
use crate::mock_server::bare_server::MockServerState;
use hyper::service::service_fn;
use futures::future::{BoxFuture, FutureExt as _};
use http_body_util::Full;
use hyper::body::Bytes;
use hyper_server::accept::Accept;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::RwLock;

/// Work around a lifetime error where, for some reason,
/// `Box<dyn std::error::Error + Send + Sync + 'static>` can't be converted to a
/// `Box<dyn std::error::Error + Send + Sync>`
struct ErrorLifetimeCast(Box<dyn std::error::Error + Send + Sync + 'static>);
pub(super) struct ErrorLifetimeCast(Box<dyn std::error::Error + Send + Sync + 'static>);

impl From<ErrorLifetimeCast> for Box<dyn std::error::Error + Send + Sync> {
fn from(value: ErrorLifetimeCast) -> Self {
value.0
}
}

/// The actual HTTP server responding to incoming requests according to the specified mocks.
pub(super) async fn run_server(
listener: std::net::TcpListener,
#[derive(Clone)]
pub(super) struct HyperRequestHandler {
server_state: Arc<RwLock<MockServerState>>,
mut shutdown_signal: tokio::sync::watch::Receiver<()>,
) {
listener
.set_nonblocking(true)
.expect("Cannot set non-blocking mode on TcpListener");
let listener = TcpListener::from_std(listener).expect("Cannot upgrade TcpListener");
}

impl hyper::service::Service<hyper::Request<hyper::body::Incoming>> for HyperRequestHandler {
type Response = hyper::Response<Full<Bytes>>;
type Error = ErrorLifetimeCast;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

let request_handler = move |request| {
let server_state = server_state.clone();
fn call(&self, request: hyper::Request<hyper::body::Incoming>) -> Self::Future {
let server_state = self.server_state.clone();
async move {
let wiremock_request = crate::Request::from_hyper(request).await;
let (response, delay) = server_state
Expand All @@ -52,6 +55,37 @@ pub(super) async fn run_server(

Ok::<_, ErrorLifetimeCast>(response)
}
.boxed()
}
}

/// The actual HTTP server responding to incoming requests according to the specified mocks.
pub(super) async fn run_server<A>(
listener: std::net::TcpListener,
server_state: Arc<RwLock<MockServerState>>,
mut shutdown_signal: tokio::sync::watch::Receiver<()>,
acceptor: A,
) where
A: Accept<tokio::net::TcpStream, HyperRequestHandler> + Send + Clone + 'static,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Future: Send,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Stream:
Unpin + Send + AsyncWrite + AsyncRead + 'static,
<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service:
hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<Full<Bytes>>>
+ Send,
<<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service
as hyper::service::Service<http::Request<hyper::body::Incoming>>>::Error:
Send + Sync + Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
<<A as Accept<tokio::net::TcpStream, HyperRequestHandler>>::Service
as hyper::service::Service<http::Request<hyper::body::Incoming>>>::Future: Send + 'static,
{
listener
.set_nonblocking(true)
.expect("Cannot set non-blocking mode on TcpListener");
let listener = TcpListener::from_std(listener).expect("Cannot upgrade TcpListener");

let request_handler = HyperRequestHandler {
server_state: server_state.clone(),
};

loop {
Expand All @@ -67,14 +101,24 @@ pub(super) async fn run_server(
break;
}
};
let io = TokioIo::new(stream);

let request_handler = request_handler.clone();
let mut shutdown_signal = shutdown_signal.clone();
let acceptor = acceptor.clone();
tokio::task::spawn(async move {
let accept = acceptor.accept(stream, request_handler).await;
let (stream, request_service) = match accept {
Ok((stream, service)) => (stream, service),
Err(e) => {
log::error!("Failed to accept connection: {}", e);
return;
}
};

let io = TokioIo::new(stream);

let http_server =
hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
let conn = http_server.serve_connection_with_upgrades(io, service_fn(request_handler));
let conn = http_server.serve_connection_with_upgrades(io, request_service);
tokio::pin!(conn);

loop {
Expand Down
Loading
Loading