From 070e462cfb9243beed4da3196f650be33d911156 Mon Sep 17 00:00:00 2001 From: David Pacheco Date: Thu, 26 Sep 2024 20:24:17 -0700 Subject: [PATCH 1/4] it compiles --- Cargo.lock | 9 +- dropshot/Cargo.toml | 1 + dropshot/src/lib.rs | 2 + dropshot/src/server.rs | 271 ++++++++++++++++++++++++++++++++++++----- 4 files changed, 248 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 06d7bcfb9..bb3458144 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -396,6 +396,7 @@ dependencies = [ "slog-term", "subprocess", "tempfile", + "thiserror", "tokio", "tokio-rustls", "tokio-tungstenite", @@ -1838,18 +1839,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 21c908a87..f41d22a01 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -42,6 +42,7 @@ slog-async = "2.8.0" slog-bunyan = "2.5.0" slog-json = "2.6.1" slog-term = "2.9.1" +thiserror = "1.0.64" tokio-rustls = "0.25.0" toml = "0.8.19" waitgroup = "0.1.2" diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index 8716777d9..e4889ccab 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -836,6 +836,8 @@ pub use pagination::PaginationOrder; pub use pagination::PaginationParams; pub use pagination::ResultsPage; pub use pagination::WhichPage; +pub use server::BuildError; +pub use server::ServerBuilder; pub use server::ServerContext; pub use server::ShutdownWaitFuture; pub use server::{HttpServer, HttpServerStarter}; diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 8816c9f77..cc0037b85 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -33,6 +33,7 @@ use std::panic; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use thiserror::Error; use tokio::io::ReadBuf; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; @@ -314,13 +315,21 @@ impl InnerHttpServerStarter { private: C, log: &Logger, handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, std::io::Error> { + ) -> Result, BuildError> { + // XXX-dap this is mostly duplicated from the Https version // We use `from_std` instead of just calling `bind` here directly // to avoid invoking an async function. - let std_listener = std::net::TcpListener::bind(&config.bind_address)?; - std_listener.set_nonblocking(true)?; - let tcp = TcpListener::from_std(std_listener)?; - let local_addr = tcp.local_addr()?; + let std_listener = std::net::TcpListener::bind(&config.bind_address) + .map_err(|e| BuildError::bind_error(e, config.bind_address))?; + std_listener.set_nonblocking(true).map_err(|e| { + BuildError::generic_system(e, "setting non-blocking") + })?; + let tcp = TcpListener::from_std(std_listener).map_err(|e| { + BuildError::generic_system(e, "creating TCP listener") + })?; + let local_addr = tcp.local_addr().map_err(|e| { + BuildError::generic_system(e, "getting local TCP address") + })?; let incoming = HttpAcceptor { tcp, log: log.new(o!("local_addr" => local_addr)) }; @@ -515,9 +524,9 @@ struct InnerHttpsServerStarter( /// Create a TLS configuration from the Dropshot config structure. impl TryFrom<&ConfigTls> for rustls::ServerConfig { - type Error = std::io::Error; + type Error = BuildError; - fn try_from(config: &ConfigTls) -> std::io::Result { + fn try_from(config: &ConfigTls) -> Result { let (mut cert_reader, mut key_reader): ( Box, Box, @@ -532,25 +541,17 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { ConfigTls::AsFile { cert_file, key_file } => { let certfile = Box::new(std::io::BufReader::new( std::fs::File::open(cert_file).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "failed to open {}: {}", - cert_file.display(), - e - ), + BuildError::generic_system( + e, + format!("opening {}", cert_file.display()), ) })?, )); let keyfile = Box::new(std::io::BufReader::new( std::fs::File::open(key_file).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "failed to open {}: {}", - key_file.display(), - e - ), + BuildError::generic_system( + e, + format!("opening {}", key_file.display()), ) })?, )); @@ -561,17 +562,17 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { let certs = rustls_pemfile::certs(&mut cert_reader) .collect::, _>>() .map_err(|err| { - io_error(format!("failed to load certificate: {err}")) + BuildError::generic_system(err, "loading TLS certificates") })?; let keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader) .collect::, _>>() .map_err(|err| { - io_error(format!("failed to load private key: {err}")) + BuildError::generic_system(err, "loading TLS private key") })?; let mut keys_iter = keys.into_iter(); let (Some(private_key), None) = (keys_iter.next(), keys_iter.next()) else { - return Err(io_error("expected a single private key".into())); + return Err(BuildError::NotOnePrivateKey); }; let mut cfg = rustls::ServerConfig::builder() @@ -639,21 +640,28 @@ impl InnerHttpsServerStarter { log: &Logger, tls: &ConfigTls, handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, GenericError> { + ) -> Result, BuildError> { let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( rustls::ServerConfig::try_from(tls)?, )))); let tcp = { - let listener = std::net::TcpListener::bind(&config.bind_address)?; - listener.set_nonblocking(true)?; + let listener = std::net::TcpListener::bind(&config.bind_address) + .map_err(|e| BuildError::bind_error(e, config.bind_address))?; + listener.set_nonblocking(true).map_err(|e| { + BuildError::generic_system(e, "setting non-blocking") + })?; // We use `from_std` instead of just calling `bind` here directly // to avoid invoking an async function, to match the interface // provided by `HttpServerStarter::new`. - TcpListener::from_std(listener)? + TcpListener::from_std(listener).map_err(|e| { + BuildError::generic_system(e, "creating TCP listener") + })? }; - let local_addr = tcp.local_addr()?; + let local_addr = tcp.local_addr().map_err(|e| { + BuildError::generic_system(e, "getting local TCP address") + })?; let logger = log.new(o!("local_addr" => local_addr)); let tcp = HttpAcceptor { tcp, log: logger.clone() }; let https_acceptor = @@ -1142,8 +1150,209 @@ impl Service> } } -fn io_error(err: String) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, err) +#[derive(Debug, Error)] +pub enum BuildError { + #[error("failed to bind to {address}")] + BindError { + address: SocketAddr, + #[source] + error: std::io::Error, + }, + #[error("expected exactly one TLS private key")] + NotOnePrivateKey, + #[error("must register an API")] + MissingApi, + #[error("only one API can be registered with a server")] + TooManyApis, + #[error("{context}")] + SystemError { + context: String, + #[source] + error: std::io::Error, + }, +} + +impl BuildError { + fn bind_error(error: std::io::Error, address: SocketAddr) -> BuildError { + BuildError::BindError { address, error } + } + + fn generic_system>( + error: std::io::Error, + context: S, + ) -> BuildError { + BuildError::SystemError { context: context.into(), error } + } +} + +#[derive(Debug)] +pub struct ServerBuilder { + // required caller-provided values + private: C, + log: Logger, + + // optional caller-provided values + config: ConfigDropshot, + tls: Option, + api: DebugIgnore>>, + + // our own internal state + error: Option, +} + +impl ServerBuilder { + pub fn new(log: Logger, private: C) -> ServerBuilder { + ServerBuilder { + private, + log, + config: Default::default(), + tls: Default::default(), + api: Default::default(), + error: Default::default(), + } + } + + pub fn config(mut self, config: ConfigDropshot) -> Self { + self.config = config; + self + } + + pub fn tls(mut self, tls: Option) -> Self { + self.tls = tls; + self + } + + pub fn api(mut self, api: ApiDescription) -> Self { + if self.api.is_none() { + self.api = DebugIgnore(Some(api)); + } else { + self.error(BuildError::TooManyApis); + } + + self + } + + fn error(&mut self, error: BuildError) { + if self.error.is_none() { + self.error = Some(error); + } + } + + pub fn build(self) -> Result, BuildError> { + let server_config = ServerConfig { + // We start aggressively to ensure test coverage. + request_body_max_bytes: self.config.request_body_max_bytes, + page_max_nitems: NonZeroU32::new(10000).unwrap(), + page_default_nitems: NonZeroU32::new(100).unwrap(), + default_handler_task_mode: self.config.default_handler_task_mode, + log_headers: self.config.log_headers.clone(), + }; + let handler_waitgroup = WaitGroup::new(); + + let config = self.config; + let private = self.private; + let log = self.log; + let tls = self.tls; + let api = self.api.0.ok_or_else(|| BuildError::MissingApi)?; + + let starter = if let Some(tls) = &tls { + let (starter, app_state, local_addr) = + InnerHttpsServerStarter::new( + &config, + server_config, + api, + private, + &log, + tls, + handler_waitgroup.worker(), + )?; + HttpServerStarter { + app_state, + local_addr, + wrapped: WrappedHttpServerStarter::Https(starter), + handler_waitgroup, + } + } else { + let (starter, app_state, local_addr) = InnerHttpServerStarter::new( + &config, + server_config, + api, + private, + &log, + handler_waitgroup.worker(), + )?; + HttpServerStarter { + app_state, + local_addr, + wrapped: WrappedHttpServerStarter::Http(starter), + handler_waitgroup, + } + }; + + let log = &starter.app_state.log; + for (path, method, _) in &starter.app_state.router { + debug!(log, "registered endpoint"; + "method" => &method, + "path" => &path + ); + } + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let log_close = starter.app_state.log.new(o!()); + let join_handle = match starter.wrapped { + WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), + WrappedHttpServerStarter::Https(https) => { + https.start(rx, log_close) + } + }; + info!(log, "listening"); + + let handler_waitgroup = starter.handler_waitgroup; + let join_handle = async move { + // After the server shuts down, we also want to wait for any + // detached handler futures to complete. + () = join_handle + .await + .map_err(|e| format!("server stopped: {e}"))?; + () = handler_waitgroup.wait().await; + Ok(()) + }; + + #[cfg(feature = "usdt-probes")] + let probe_registration = match usdt::register_probes() { + Ok(_) => { + debug!( + starter.app_state.log, + "successfully registered DTrace USDT probes" + ); + ProbeRegistration::Succeeded + } + Err(e) => { + let msg = e.to_string(); + error!( + starter.app_state.log, + "failed to register DTrace USDT probes: {}", msg + ); + ProbeRegistration::Failed(msg) + } + }; + #[cfg(not(feature = "usdt-probes"))] + let probe_registration = { + debug!( + starter.app_state.log, + "DTrace USDT probes compiled out, not registering" + ); + ProbeRegistration::Disabled + }; + + Ok(HttpServer { + probe_registration, + app_state: starter.app_state, + local_addr: starter.local_addr, + closer: CloseHandle { close_channel: Some(tx) }, + join_future: join_handle.boxed().shared(), + }) + } } #[cfg(test)] From 95a0071c3bba86d80134d2f679ddbcf32a58cd34 Mon Sep 17 00:00:00 2001 From: David Pacheco Date: Thu, 26 Sep 2024 20:49:07 -0700 Subject: [PATCH 2/4] commonize a bunch of the HTTP and HTTPS startup code --- dropshot/src/server.rs | 491 ++++++++++++++++++++--------------------- 1 file changed, 236 insertions(+), 255 deletions(-) diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index cc0037b85..fcb72dfbf 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -122,126 +122,128 @@ impl HttpServerStarter { } pub fn new_with_tls( - config: &ConfigDropshot, - api: ApiDescription, - private: C, - log: &Logger, - tls: Option, + _config: &ConfigDropshot, + _api: ApiDescription, + _private: C, + _log: &Logger, + _tls: Option, ) -> Result, GenericError> { - let server_config = ServerConfig { - // We start aggressively to ensure test coverage. - request_body_max_bytes: config.request_body_max_bytes, - page_max_nitems: NonZeroU32::new(10000).unwrap(), - page_default_nitems: NonZeroU32::new(100).unwrap(), - default_handler_task_mode: config.default_handler_task_mode, - log_headers: config.log_headers.clone(), - }; - - let handler_waitgroup = WaitGroup::new(); - let starter = match &tls { - Some(tls) => { - let (starter, app_state, local_addr) = - InnerHttpsServerStarter::new( - config, - server_config, - api, - private, - log, - tls, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Https(starter), - handler_waitgroup, - } - } - None => { - let (starter, app_state, local_addr) = - InnerHttpServerStarter::new( - config, - server_config, - api, - private, - log, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Http(starter), - handler_waitgroup, - } - } - }; - - for (path, method, _) in &starter.app_state.router { - debug!(starter.app_state.log, "registered endpoint"; - "method" => &method, - "path" => &path - ); - } - - Ok(starter) + todo!(); // XXX-dap + // let server_config = ServerConfig { + // // We start aggressively to ensure test coverage. + // request_body_max_bytes: config.request_body_max_bytes, + // page_max_nitems: NonZeroU32::new(10000).unwrap(), + // page_default_nitems: NonZeroU32::new(100).unwrap(), + // default_handler_task_mode: config.default_handler_task_mode, + // log_headers: config.log_headers.clone(), + // }; + + // let handler_waitgroup = WaitGroup::new(); + // let starter = match &tls { + // Some(tls) => { + // let (starter, app_state, local_addr) = + // InnerHttpsServerStarter::new( + // config, + // server_config, + // api, + // private, + // log, + // tls, + // handler_waitgroup.worker(), + // )?; + // HttpServerStarter { + // app_state, + // local_addr, + // wrapped: WrappedHttpServerStarter::Https(starter), + // handler_waitgroup, + // } + // } + // None => { + // let (starter, app_state, local_addr) = + // InnerHttpServerStarter::new( + // config, + // server_config, + // api, + // private, + // log, + // handler_waitgroup.worker(), + // )?; + // HttpServerStarter { + // app_state, + // local_addr, + // wrapped: WrappedHttpServerStarter::Http(starter), + // handler_waitgroup, + // } + // } + // }; + + // for (path, method, _) in &starter.app_state.router { + // debug!(starter.app_state.log, "registered endpoint"; + // "method" => &method, + // "path" => &path + // ); + // } + + // Ok(starter) } pub fn start(self) -> HttpServer { - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let log_close = self.app_state.log.new(o!()); - let join_handle = match self.wrapped { - WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), - WrappedHttpServerStarter::Https(https) => { - https.start(rx, log_close) - } - }; - info!(self.app_state.log, "listening"); - - let handler_waitgroup = self.handler_waitgroup; - let join_handle = async move { - // After the server shuts down, we also want to wait for any - // detached handler futures to complete. - () = join_handle - .await - .map_err(|e| format!("server stopped: {e}"))?; - () = handler_waitgroup.wait().await; - Ok(()) - }; - - #[cfg(feature = "usdt-probes")] - let probe_registration = match usdt::register_probes() { - Ok(_) => { - debug!( - self.app_state.log, - "successfully registered DTrace USDT probes" - ); - ProbeRegistration::Succeeded - } - Err(e) => { - let msg = e.to_string(); - error!( - self.app_state.log, - "failed to register DTrace USDT probes: {}", msg - ); - ProbeRegistration::Failed(msg) - } - }; - #[cfg(not(feature = "usdt-probes"))] - let probe_registration = { - debug!( - self.app_state.log, - "DTrace USDT probes compiled out, not registering" - ); - ProbeRegistration::Disabled - }; - - HttpServer { - probe_registration, - app_state: self.app_state, - local_addr: self.local_addr, - closer: CloseHandle { close_channel: Some(tx) }, - join_future: join_handle.boxed().shared(), - } + todo!(); + // let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + // let log_close = self.app_state.log.new(o!()); + // let join_handle = match self.wrapped { + // WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), + // WrappedHttpServerStarter::Https(https) => { + // https.start(rx, log_close) + // } + // }; + // info!(self.app_state.log, "listening"); + + // let handler_waitgroup = self.handler_waitgroup; + // let join_handle = async move { + // // After the server shuts down, we also want to wait for any + // // detached handler futures to complete. + // () = join_handle + // .await + // .map_err(|e| format!("server stopped: {e}"))?; + // () = handler_waitgroup.wait().await; + // Ok(()) + // }; + + // #[cfg(feature = "usdt-probes")] + // let probe_registration = match usdt::register_probes() { + // Ok(_) => { + // debug!( + // self.app_state.log, + // "successfully registered DTrace USDT probes" + // ); + // ProbeRegistration::Succeeded + // } + // Err(e) => { + // let msg = e.to_string(); + // error!( + // self.app_state.log, + // "failed to register DTrace USDT probes: {}", msg + // ); + // ProbeRegistration::Failed(msg) + // } + // }; + // #[cfg(not(feature = "usdt-probes"))] + // let probe_registration = { + // debug!( + // self.app_state.log, + // "DTrace USDT probes compiled out, not registering" + // ); + // ProbeRegistration::Disabled + // }; + + // HttpServer { + // probe_registration, + // app_state: self.app_state, + // local_addr: self.local_addr, + // closer: CloseHandle { close_channel: Some(tx) }, + // join_future: join_handle.boxed().shared(), + // } } } @@ -309,46 +311,33 @@ impl InnerHttpServerStarter { /// of `HttpServerStarter` (and await the result) to actually start the /// server. fn new( - config: &ConfigDropshot, - server_config: ServerConfig, - api: ApiDescription, - private: C, - log: &Logger, - handler_waitgroup_worker: waitgroup::Worker, + _config: &ConfigDropshot, + _server_config: ServerConfig, + _api: ApiDescription, + _private: C, + _log: &Logger, + _handler_waitgroup_worker: waitgroup::Worker, ) -> Result, BuildError> { - // XXX-dap this is mostly duplicated from the Https version - // We use `from_std` instead of just calling `bind` here directly - // to avoid invoking an async function. - let std_listener = std::net::TcpListener::bind(&config.bind_address) - .map_err(|e| BuildError::bind_error(e, config.bind_address))?; - std_listener.set_nonblocking(true).map_err(|e| { - BuildError::generic_system(e, "setting non-blocking") - })?; - let tcp = TcpListener::from_std(std_listener).map_err(|e| { - BuildError::generic_system(e, "creating TCP listener") - })?; - let local_addr = tcp.local_addr().map_err(|e| { - BuildError::generic_system(e, "getting local TCP address") - })?; - let incoming = - HttpAcceptor { tcp, log: log.new(o!("local_addr" => local_addr)) }; - - let app_state = Arc::new(DropshotState { - private, - config: server_config, - router: api.into_router(), - log: log.new(o!("local_addr" => local_addr)), - local_addr, - tls_acceptor: None, - handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - }); - - let make_service = ServerConnectionHandler::new(app_state.clone()); - Ok(( - InnerHttpServerStarter(incoming, make_service), - app_state, - local_addr, - )) + todo!(); + // let incoming = + // HttpAcceptor { tcp, log: log.new(o!("local_addr" => local_addr)) }; + + // let app_state = Arc::new(DropshotState { + // private, + // config: server_config, + // router: api.into_router(), + // log: logger, + // local_addr, + // tls_acceptor: None, + // handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), + // }); + + // let make_service = ServerConnectionHandler::new(app_state.clone()); + // Ok(( + // InnerHttpServerStarter(incoming, make_service), + // app_state, + // local_addr, + // )) } } @@ -633,57 +622,40 @@ impl InnerHttpsServerStarter { } fn new( - config: &ConfigDropshot, - server_config: ServerConfig, - api: ApiDescription, - private: C, - log: &Logger, - tls: &ConfigTls, - handler_waitgroup_worker: waitgroup::Worker, + _config: &ConfigDropshot, + _server_config: ServerConfig, + _api: ApiDescription, + _private: C, + _log: &Logger, + _tls: &ConfigTls, + _handler_waitgroup_worker: waitgroup::Worker, ) -> Result, BuildError> { - let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( - rustls::ServerConfig::try_from(tls)?, - )))); - - let tcp = { - let listener = std::net::TcpListener::bind(&config.bind_address) - .map_err(|e| BuildError::bind_error(e, config.bind_address))?; - listener.set_nonblocking(true).map_err(|e| { - BuildError::generic_system(e, "setting non-blocking") - })?; - // We use `from_std` instead of just calling `bind` here directly - // to avoid invoking an async function, to match the interface - // provided by `HttpServerStarter::new`. - TcpListener::from_std(listener).map_err(|e| { - BuildError::generic_system(e, "creating TCP listener") - })? - }; - - let local_addr = tcp.local_addr().map_err(|e| { - BuildError::generic_system(e, "getting local TCP address") - })?; - let logger = log.new(o!("local_addr" => local_addr)); - let tcp = HttpAcceptor { tcp, log: logger.clone() }; - let https_acceptor = - HttpsAcceptor::new(logger.clone(), acceptor.clone(), tcp); - - let app_state = Arc::new(DropshotState { - private, - config: server_config, - router: api.into_router(), - log: logger, - local_addr, - tls_acceptor: Some(acceptor), - handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - }); - - let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); - - Ok(( - InnerHttpsServerStarter(https_acceptor, make_service), - app_state, - local_addr, - )) + todo!(); + // let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( + // rustls::ServerConfig::try_from(tls)?, + // )))); + + // let tcp = HttpAcceptor { tcp, log: logger.clone() }; + // let https_acceptor = + // HttpsAcceptor::new(logger.clone(), acceptor.clone(), tcp); + + // let app_state = Arc::new(DropshotState { + // private, + // config: server_config, + // router: api.into_router(), + // log: logger, + // local_addr, + // tls_acceptor: Some(acceptor), + // handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), + // }); + + // let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); + + // Ok(( + // InnerHttpsServerStarter(https_acceptor, make_service), + // app_state, + // local_addr, + // )) } } @@ -1255,42 +1227,61 @@ impl ServerBuilder { let tls = self.tls; let api = self.api.0.ok_or_else(|| BuildError::MissingApi)?; - let starter = if let Some(tls) = &tls { - let (starter, app_state, local_addr) = - InnerHttpsServerStarter::new( - &config, - server_config, - api, - private, - &log, - tls, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Https(starter), - handler_waitgroup, - } - } else { - let (starter, app_state, local_addr) = InnerHttpServerStarter::new( - &config, - server_config, - api, - private, - &log, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Http(starter), - handler_waitgroup, + let std_listener = std::net::TcpListener::bind(&config.bind_address) + .map_err(|e| BuildError::bind_error(e, config.bind_address))?; + std_listener.set_nonblocking(true).map_err(|e| { + BuildError::generic_system(e, "setting non-blocking") + })?; + // We use `from_std` instead of just calling `bind` here directly + // to avoid invoking an async function. + let tcp = TcpListener::from_std(std_listener).map_err(|e| { + BuildError::generic_system(e, "creating TCP listener") + })?; + let local_addr = tcp.local_addr().map_err(|e| { + BuildError::generic_system(e, "getting local TCP address") + })?; + + let log = log.new(o!("local_addr" => local_addr)); + + let tls_acceptor = tls + .as_ref() + .map(|tls| { + Ok(Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( + rustls::ServerConfig::try_from(tls)?, + ))))) + }) + .transpose()?; + + let app_state = Arc::new(DropshotState { + private, + config: server_config, + router: api.into_router(), + log: log.clone(), + local_addr, + tls_acceptor: tls_acceptor.clone(), + handler_waitgroup_worker: DebugIgnore(handler_waitgroup.worker()), + }); + let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); + + let incoming = HttpAcceptor { tcp, log: log.clone() }; + + let inner_starter = match tls_acceptor { + Some(tls_acceptor) => { + let https_acceptor = + HttpsAcceptor::new(log.clone(), tls_acceptor, incoming); + WrappedHttpServerStarter::Https(InnerHttpsServerStarter( + https_acceptor, + make_service, + )) } + None => WrappedHttpServerStarter::Http(InnerHttpServerStarter( + incoming, + make_service, + )), }; - let log = &starter.app_state.log; - for (path, method, _) in &starter.app_state.router { + let log = &app_state.log; + for (path, method, _) in &app_state.router { debug!(log, "registered endpoint"; "method" => &method, "path" => &path @@ -1298,8 +1289,8 @@ impl ServerBuilder { } let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let log_close = starter.app_state.log.new(o!()); - let join_handle = match starter.wrapped { + let log_close = app_state.log.new(o!()); + let join_handle = match inner_starter { WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), WrappedHttpServerStarter::Https(https) => { https.start(rx, log_close) @@ -1307,7 +1298,6 @@ impl ServerBuilder { }; info!(log, "listening"); - let handler_waitgroup = starter.handler_waitgroup; let join_handle = async move { // After the server shuts down, we also want to wait for any // detached handler futures to complete. @@ -1321,34 +1311,25 @@ impl ServerBuilder { #[cfg(feature = "usdt-probes")] let probe_registration = match usdt::register_probes() { Ok(_) => { - debug!( - starter.app_state.log, - "successfully registered DTrace USDT probes" - ); + debug!(&log, "successfully registered DTrace USDT probes"); ProbeRegistration::Succeeded } Err(e) => { let msg = e.to_string(); - error!( - starter.app_state.log, - "failed to register DTrace USDT probes: {}", msg - ); + error!(&log, "failed to register DTrace USDT probes: {}", msg); ProbeRegistration::Failed(msg) } }; #[cfg(not(feature = "usdt-probes"))] let probe_registration = { - debug!( - starter.app_state.log, - "DTrace USDT probes compiled out, not registering" - ); + debug!(&log, "DTrace USDT probes compiled out, not registering"); ProbeRegistration::Disabled }; Ok(HttpServer { probe_registration, - app_state: starter.app_state, - local_addr: starter.local_addr, + app_state, + local_addr, closer: CloseHandle { close_channel: Some(tx) }, join_future: join_handle.boxed().shared(), }) From c7319d5cef2194eea60db5d7f64b653be776bdd2 Mon Sep 17 00:00:00 2001 From: David Pacheco Date: Thu, 26 Sep 2024 20:53:00 -0700 Subject: [PATCH 3/4] remove the now-unused code and impl the old interface in terms of the new one --- dropshot/src/server.rs | 214 +++-------------------------------------- 1 file changed, 15 insertions(+), 199 deletions(-) diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index fcb72dfbf..691225171 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -105,10 +105,7 @@ pub struct ServerConfig { } pub struct HttpServerStarter { - app_state: Arc>, - local_addr: SocketAddr, - wrapped: WrappedHttpServerStarter, - handler_waitgroup: WaitGroup, + server: HttpServer, } impl HttpServerStarter { @@ -122,128 +119,24 @@ impl HttpServerStarter { } pub fn new_with_tls( - _config: &ConfigDropshot, - _api: ApiDescription, - _private: C, - _log: &Logger, - _tls: Option, + config: &ConfigDropshot, + api: ApiDescription, + private: C, + log: &Logger, + tls: Option, ) -> Result, GenericError> { - todo!(); // XXX-dap - // let server_config = ServerConfig { - // // We start aggressively to ensure test coverage. - // request_body_max_bytes: config.request_body_max_bytes, - // page_max_nitems: NonZeroU32::new(10000).unwrap(), - // page_default_nitems: NonZeroU32::new(100).unwrap(), - // default_handler_task_mode: config.default_handler_task_mode, - // log_headers: config.log_headers.clone(), - // }; - - // let handler_waitgroup = WaitGroup::new(); - // let starter = match &tls { - // Some(tls) => { - // let (starter, app_state, local_addr) = - // InnerHttpsServerStarter::new( - // config, - // server_config, - // api, - // private, - // log, - // tls, - // handler_waitgroup.worker(), - // )?; - // HttpServerStarter { - // app_state, - // local_addr, - // wrapped: WrappedHttpServerStarter::Https(starter), - // handler_waitgroup, - // } - // } - // None => { - // let (starter, app_state, local_addr) = - // InnerHttpServerStarter::new( - // config, - // server_config, - // api, - // private, - // log, - // handler_waitgroup.worker(), - // )?; - // HttpServerStarter { - // app_state, - // local_addr, - // wrapped: WrappedHttpServerStarter::Http(starter), - // handler_waitgroup, - // } - // } - // }; - - // for (path, method, _) in &starter.app_state.router { - // debug!(starter.app_state.log, "registered endpoint"; - // "method" => &method, - // "path" => &path - // ); - // } - - // Ok(starter) + Ok(Self { + server: ServerBuilder::new(log.clone(), private) + .tls(tls) + .api(api) + .config(config.clone()) + .build() + .map_err(Box::new)?, + }) } pub fn start(self) -> HttpServer { - todo!(); - // let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - // let log_close = self.app_state.log.new(o!()); - // let join_handle = match self.wrapped { - // WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), - // WrappedHttpServerStarter::Https(https) => { - // https.start(rx, log_close) - // } - // }; - // info!(self.app_state.log, "listening"); - - // let handler_waitgroup = self.handler_waitgroup; - // let join_handle = async move { - // // After the server shuts down, we also want to wait for any - // // detached handler futures to complete. - // () = join_handle - // .await - // .map_err(|e| format!("server stopped: {e}"))?; - // () = handler_waitgroup.wait().await; - // Ok(()) - // }; - - // #[cfg(feature = "usdt-probes")] - // let probe_registration = match usdt::register_probes() { - // Ok(_) => { - // debug!( - // self.app_state.log, - // "successfully registered DTrace USDT probes" - // ); - // ProbeRegistration::Succeeded - // } - // Err(e) => { - // let msg = e.to_string(); - // error!( - // self.app_state.log, - // "failed to register DTrace USDT probes: {}", msg - // ); - // ProbeRegistration::Failed(msg) - // } - // }; - // #[cfg(not(feature = "usdt-probes"))] - // let probe_registration = { - // debug!( - // self.app_state.log, - // "DTrace USDT probes compiled out, not registering" - // ); - // ProbeRegistration::Disabled - // }; - - // HttpServer { - // probe_registration, - // app_state: self.app_state, - // local_addr: self.local_addr, - // closer: CloseHandle { close_channel: Some(tx) }, - // join_future: join_handle.boxed().shared(), - // } + self.server } } @@ -257,9 +150,6 @@ struct InnerHttpServerStarter( ServerConnectionHandler, ); -type InnerHttpServerStarterNewReturn = - (InnerHttpServerStarter, Arc>, SocketAddr); - impl InnerHttpServerStarter { /// Begins execution of the underlying Http server. fn start( @@ -305,40 +195,6 @@ impl InnerHttpServerStarter { graceful.shutdown().await }) } - - /// Set up an HTTP server bound on the specified address that runs - /// registered handlers. You must invoke `start()` on the returned instance - /// of `HttpServerStarter` (and await the result) to actually start the - /// server. - fn new( - _config: &ConfigDropshot, - _server_config: ServerConfig, - _api: ApiDescription, - _private: C, - _log: &Logger, - _handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, BuildError> { - todo!(); - // let incoming = - // HttpAcceptor { tcp, log: log.new(o!("local_addr" => local_addr)) }; - - // let app_state = Arc::new(DropshotState { - // private, - // config: server_config, - // router: api.into_router(), - // log: logger, - // local_addr, - // tls_acceptor: None, - // handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - // }); - - // let make_service = ServerConnectionHandler::new(app_state.clone()); - // Ok(( - // InnerHttpServerStarter(incoming, make_service), - // app_state, - // local_addr, - // )) - } } /// Accepts TCP connections like a `TcpListener`, but ignores transient errors rather than propagating them to the caller @@ -573,9 +429,6 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { } } -type InnerHttpsServerStarterNewReturn = - (InnerHttpsServerStarter, Arc>, SocketAddr); - impl InnerHttpsServerStarter { /// Begins execution of the underlying Http server. fn start( @@ -620,43 +473,6 @@ impl InnerHttpsServerStarter { graceful.shutdown().await }) } - - fn new( - _config: &ConfigDropshot, - _server_config: ServerConfig, - _api: ApiDescription, - _private: C, - _log: &Logger, - _tls: &ConfigTls, - _handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, BuildError> { - todo!(); - // let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( - // rustls::ServerConfig::try_from(tls)?, - // )))); - - // let tcp = HttpAcceptor { tcp, log: logger.clone() }; - // let https_acceptor = - // HttpsAcceptor::new(logger.clone(), acceptor.clone(), tcp); - - // let app_state = Arc::new(DropshotState { - // private, - // config: server_config, - // router: api.into_router(), - // log: logger, - // local_addr, - // tls_acceptor: Some(acceptor), - // handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - // }); - - // let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); - - // Ok(( - // InnerHttpsServerStarter(https_acceptor, make_service), - // app_state, - // local_addr, - // )) - } } type SharedBoxFuture = Shared + Send>>>; From 32c26b2dc5b06b1e4141ae5d165aafa02d8490ee Mon Sep 17 00:00:00 2001 From: David Pacheco Date: Thu, 26 Sep 2024 21:28:41 -0700 Subject: [PATCH 4/4] remove more unnecessarily bifurcated code --- dropshot/src/server.rs | 207 +++++++++++++++-------------------------- 1 file changed, 76 insertions(+), 131 deletions(-) diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 691225171..2bc59f0bb 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -140,63 +140,6 @@ impl HttpServerStarter { } } -enum WrappedHttpServerStarter { - Http(InnerHttpServerStarter), - Https(InnerHttpsServerStarter), -} - -struct InnerHttpServerStarter( - HttpAcceptor, - ServerConnectionHandler, -); - -impl InnerHttpServerStarter { - /// Begins execution of the underlying Http server. - fn start( - self, - mut close_signal: tokio::sync::oneshot::Receiver<()>, - log_close: Logger, - ) -> tokio::task::JoinHandle<()> { - use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; - use hyper_util::server::conn::auto; - - tokio::spawn(async move { - let mut builder = auto::Builder::new(TokioExecutor::new()); - // http/1 settings - builder.http1().timer(TokioTimer::new()); - // http/2 settings - builder.http2().timer(TokioTimer::new()); - - // Use a graceful watcher to keep track of all existing connections, - // and when the close_signal is trigger, force all known conns - // to start a graceful shutdown. - let graceful = - hyper_util::server::graceful::GracefulShutdown::new(); - - loop { - tokio::select! { - (sock, remote_addr) = self.0.accept() => { - let fut = builder.serve_connection_with_upgrades( - TokioIo::new(sock), - self.1.make_http_request_handler(remote_addr), - ); - let fut = graceful.watch(fut.into_owned()); - tokio::spawn(fut); - }, - - _ = &mut close_signal => { - info!(log_close, "received request to begin graceful shutdown"); - break; - } - } - } - - // optional: could use another select on a timeout - graceful.shutdown().await - }) - } -} - /// Accepts TCP connections like a `TcpListener`, but ignores transient errors rather than propagating them to the caller struct HttpAcceptor { tcp: TcpListener, @@ -362,11 +305,6 @@ impl HttpsAcceptor { } } -struct InnerHttpsServerStarter( - HttpsAcceptor, - ServerConnectionHandler, -); - /// Create a TLS configuration from the Dropshot config structure. impl TryFrom<&ConfigTls> for rustls::ServerConfig { type Error = BuildError; @@ -429,52 +367,6 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { } } -impl InnerHttpsServerStarter { - /// Begins execution of the underlying Http server. - fn start( - mut self, - mut close_signal: tokio::sync::oneshot::Receiver<()>, - log_close: Logger, - ) -> tokio::task::JoinHandle<()> { - use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; - use hyper_util::server::conn::auto; - - tokio::spawn(async move { - let mut builder = auto::Builder::new(TokioExecutor::new()); - // http/1 settings - builder.http1().timer(TokioTimer::new()); - - // Use a graceful watcher to keep track of all existing connections, - // and when the close_signal is trigger, force all known conns - // to start a graceful shutdown. - let graceful = - hyper_util::server::graceful::GracefulShutdown::new(); - - loop { - tokio::select! { - Some(Ok(sock)) = self.0.accept() => { - let remote_addr = sock.remote_addr(); - let fut = builder.serve_connection_with_upgrades( - TokioIo::new(sock), - self.1.make_http_request_handler(remote_addr), - ); - let fut = graceful.watch(fut.into_owned()); - tokio::spawn(fut); - }, - - _ = &mut close_signal => { - info!(log_close, "received request to begin graceful shutdown"); - break; - } - } - } - - // optional: could use another select on a timeout - graceful.shutdown().await - }) - } -} - type SharedBoxFuture = Shared + Send>>>; /// Future returned by [`HttpServer::wait_for_shutdown()`]. @@ -1081,21 +973,6 @@ impl ServerBuilder { let incoming = HttpAcceptor { tcp, log: log.clone() }; - let inner_starter = match tls_acceptor { - Some(tls_acceptor) => { - let https_acceptor = - HttpsAcceptor::new(log.clone(), tls_acceptor, incoming); - WrappedHttpServerStarter::Https(InnerHttpsServerStarter( - https_acceptor, - make_service, - )) - } - None => WrappedHttpServerStarter::Http(InnerHttpServerStarter( - incoming, - make_service, - )), - }; - let log = &app_state.log; for (path, method, _) in &app_state.router { debug!(log, "registered endpoint"; @@ -1104,14 +981,82 @@ impl ServerBuilder { ); } - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let log_close = app_state.log.new(o!()); - let join_handle = match inner_starter { - WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), - WrappedHttpServerStarter::Https(https) => { - https.start(rx, log_close) - } - }; + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + + let log_close = log.clone(); + let join_handle = tokio::spawn(async move { + use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; + use hyper_util::server::conn::auto; + + let mut builder = auto::Builder::new(TokioExecutor::new()); + // http/1 settings + builder.http1().timer(TokioTimer::new()); + // XXX-dap previously, the TLS one did NOT do this http2 step + // http/2 settings + builder.http2().timer(TokioTimer::new()); + + // Use a graceful watcher to keep track of all existing connections, + // and when the close_signal is trigger, force all known conns + // to start a graceful shutdown. + let graceful = + hyper_util::server::graceful::GracefulShutdown::new(); + + // The following code looks superficially similar between the HTTP + // and HTTPS paths. However, the concrete types of various objects + // are different and so it's not easy to actually share the code. + let log = log_close; + match tls_acceptor { + Some(tls_acceptor) => { + let mut https_acceptor = + HttpsAcceptor::new(log.clone(), tls_acceptor, incoming); + loop { + tokio::select! { + Some(Ok(sock)) = https_acceptor.accept() => { + let remote_addr = sock.remote_addr(); + let handler = make_service + .make_http_request_handler(remote_addr); + let fut = builder + .serve_connection_with_upgrades( + TokioIo::new(sock), + handler, + ); + let fut = graceful.watch(fut.into_owned()); + tokio::spawn(fut); + }, + + _ = &mut rx => { + info!(log, "beginning graceful shutdown"); + break; + } + } + } + } + None => loop { + tokio::select! { + (sock, remote_addr) = incoming.accept() => { + let handler = make_service + .make_http_request_handler(remote_addr); + let fut = builder + .serve_connection_with_upgrades( + TokioIo::new(sock), + handler, + ); + let fut = graceful.watch(fut.into_owned()); + tokio::spawn(fut); + }, + + _ = &mut rx => { + info!(log, "beginning graceful shutdown"); + break; + } + } + }, + }; + + // optional: could use another select on a timeout + graceful.shutdown().await + }); + info!(log, "listening"); let join_handle = async move {