diff --git a/Cargo.toml b/Cargo.toml index 9e7758d..0cc6af4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ tower = { version = "0.4.1", optional = true, default-features = false, features hyper = { version = "1.3.0", features = ["full"] } bytes = "1" http-body-util = "0.1.0" -tokio = { version = "1", features = ["macros", "test-util"] } +tokio = { version = "1", features = ["macros", "test-util", "signal"] } tokio-test = "0.4" pretty_env_logger = "0.5" @@ -51,6 +51,7 @@ full = [ "client-legacy", "server", "server-auto", + "server-graceful", "service", "http1", "http2", @@ -62,6 +63,7 @@ client-legacy = ["client", "dep:socket2", "tokio/sync"] server = ["hyper/server"] server-auto = ["server", "http1", "http2"] +server-graceful = ["server", "tokio/sync"] service = ["dep:tower", "dep:tower-service"] @@ -80,3 +82,7 @@ required-features = ["client-legacy", "http1", "tokio"] [[example]] name = "server" required-features = ["server", "http1", "tokio"] + +[[example]] +name = "server_graceful" +required-features = ["tokio", "server-graceful", "server-auto"] diff --git a/examples/server_graceful.rs b/examples/server_graceful.rs new file mode 100644 index 0000000..bfb43a4 --- /dev/null +++ b/examples/server_graceful.rs @@ -0,0 +1,64 @@ +use bytes::Bytes; +use std::convert::Infallible; +use std::pin::pin; +use std::time::Duration; +use tokio::net::TcpListener; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let listener = TcpListener::bind("127.0.0.1:8080").await?; + + let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let graceful = hyper_util::server::graceful::GracefulShutdown::new(); + let mut ctrl_c = pin!(tokio::signal::ctrl_c()); + + loop { + tokio::select! { + conn = listener.accept() => { + let (stream, peer_addr) = match conn { + Ok(conn) => conn, + Err(e) => { + eprintln!("accept error: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + eprintln!("incomming connection accepted: {}", peer_addr); + + let stream = hyper_util::rt::TokioIo::new(Box::pin(stream)); + + let conn = server.serve_connection_with_upgrades(stream, hyper::service::service_fn(|_| async move { + tokio::time::sleep(Duration::from_secs(5)).await; // emulate slow request + let body = http_body_util::Full::::from("Hello World!".to_owned()); + Ok::<_, Infallible>(http::Response::new(body)) + })); + + let conn = graceful.watch(conn.into_owned()); + + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("connection error: {}", err); + } + eprintln!("connection dropped: {}", peer_addr); + }); + }, + + _ = ctrl_c.as_mut() => { + drop(listener); + eprintln!("Ctrl-C received, starting shutdown"); + break; + } + } + } + + tokio::select! { + _ = graceful.shutdown() => { + eprintln!("Gracefully shutdown!"); + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + eprintln!("Waited 10 seconds for graceful shutdown, aborting..."); + } + } + + Ok(()) +} diff --git a/src/server/graceful.rs b/src/server/graceful.rs new file mode 100644 index 0000000..fc92d3b --- /dev/null +++ b/src/server/graceful.rs @@ -0,0 +1,446 @@ +//! Utility to gracefully shutdown a server. +//! +//! This module provides a [`GracefulShutdown`] type, +//! which can be used to gracefully shutdown a server. +//! +//! See +//! for an example of how to use this. + +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + task::{self, Poll}, +}; + +use pin_project_lite::pin_project; +use tokio::sync::watch; + +/// A graceful shutdown utility +pub struct GracefulShutdown { + tx: watch::Sender<()>, + rx: watch::Receiver<()>, +} + +impl GracefulShutdown { + /// Create a new graceful shutdown helper. + pub fn new() -> Self { + let (tx, rx) = watch::channel(()); + Self { tx, rx } + } + + /// Wrap a future for graceful shutdown watching. + pub fn watch(&self, conn: C) -> impl Future { + let mut rx = self.rx.clone(); + GracefulConnectionFuture::new(conn, async move { + let _ = rx.changed().await; + // hold onto the rx until the watched future is completed + rx + }) + } + + /// Signal shutdown for all watched connections. + /// + /// This returns a `Future` which will complete once all watched + /// connections have shutdown. + pub async fn shutdown(self) { + // drop the rx immediately, or else it will hold us up + let Self { tx, rx } = self; + drop(rx); + + // signal all the watched futures about the change + let _ = tx.send(()); + // and then wait for all of them to complete + tx.closed().await; + } +} + +impl Debug for GracefulShutdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulShutdown").finish() + } +} + +impl Default for GracefulShutdown { + fn default() -> Self { + Self::new() + } +} + +pin_project! { + struct GracefulConnectionFuture { + #[pin] + conn: C, + #[pin] + cancel: F, + #[pin] + // If cancelled, this is held until the inner conn is done. + cancelled_guard: Option, + } +} + +impl GracefulConnectionFuture { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled_guard: None, + } + } +} + +impl Debug for GracefulConnectionFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl Future for GracefulConnectionFuture +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let mut this = self.project(); + if this.cancelled_guard.is_none() { + if let Poll::Ready(guard) = this.cancel.poll(cx) { + this.cancelled_guard.set(Some(guard)); + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} + +/// An internal utility trait as an umbrella target for all (hyper) connection +/// types that the [`GracefulShutdown`] can watch. +pub trait GracefulConnection: Future> + private::Sealed { + /// The error type returned by the connection when used as a future. + type Error; + + /// Start a graceful shutdown process for this connection. + fn graceful_shutdown(self: Pin<&mut Self>); +} + +#[cfg(feature = "http1")] +impl GracefulConnection for hyper::server::conn::http1::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http1::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "http2")] +impl GracefulConnection for hyper::server::conn::http2::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http2::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<'a, I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<'a, I, B, S, E> GracefulConnection + for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); + } +} + +mod private { + pub trait Sealed {} + + #[cfg(feature = "http1")] + impl Sealed for hyper::server::conn::http1::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + #[cfg(feature = "http1")] + impl Sealed for hyper::server::conn::http1::UpgradeableConnection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + #[cfg(feature = "http2")] + impl Sealed for hyper::server::conn::http2::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + #[cfg(feature = "server-auto")] + impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::Connection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + #[cfg(feature = "server-auto")] + impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } +} + +#[cfg(test)] +mod test { + use super::*; + use pin_project_lite::pin_project; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + pin_project! { + #[derive(Debug)] + struct DummyConnection { + #[pin] + future: F, + shutdown_counter: Arc, + } + } + + impl private::Sealed for DummyConnection {} + + impl GracefulConnection for DummyConnection { + type Error = (); + + fn graceful_shutdown(self: Pin<&mut Self>) { + self.shutdown_counter.fetch_add(1, Ordering::SeqCst); + } + } + + impl Future for DummyConnection { + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match self.project().future.poll(cx) { + Poll::Ready(_) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + let (dummy_tx, _) = tokio::sync::broadcast::channel(1); + + for i in 1..=3 { + let mut dummy_rx = dummy_tx.subscribe(); + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await; + let _ = dummy_rx.recv().await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + let _ = dummy_tx.send(()); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_delayed_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + //tokio::time::sleep(std::time::Duration::from_millis(i * 5)).await; + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_multi_per_watcher_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let mut futures = Vec::new(); + for u in 1..=i { + let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50)); + let dummy_conn = DummyConnection { + future, + shutdown_counter: shutdown_counter.clone(), + }; + let conn = graceful.watch(dummy_conn); + futures.push(conn); + } + tokio::spawn(async move { + futures_util::future::join_all(futures).await; + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_timeout() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + if i == 1 { + std::future::pending::<()>().await + } else { + std::future::ready(()).await + } + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + }, + _ = graceful.shutdown() => { + panic!("shutdown should not be completed: as not all our conns finish") + } + } + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 7b4515c..a4838ac 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,3 +1,6 @@ //! Server utilities. pub mod conn; + +#[cfg(feature = "server-graceful")] +pub mod graceful;