From 85a92f3da3c1b8b0d20c71f8d9a7063f98ab4eb9 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 19 Dec 2023 18:16:02 +0100 Subject: [PATCH] Support graceful shutdown on "auto conn" (#66) * remove references from `ReadVersion` * expose `fn graceful_shutdown(self: Pin<&mut Self>)` on Connection * support graceful shutdown on upgradeable connections * format * update to hyper 1.1.0 --- Cargo.toml | 4 +- src/server/conn/auto.rs | 326 ++++++++++++++++++++++++++++++++++------ 2 files changed, 281 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 841c4b9..d4e6328 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ features = ["full"] rustdoc-args = ["--cfg", "docsrs"] [dependencies] -hyper = "1.0.0" +hyper = "1.1.0" futures-channel = "0.3" futures-util = { version = "0.3.16", default-features = false } http = "1.0" @@ -31,7 +31,7 @@ tower-service ={ version = "0.3", optional = true } tower = { version = "0.4.1", optional = true, features = ["make", "util"] } [dev-dependencies] -hyper = { version = "1.0.0", features = ["full"] } +hyper = { version = "1.1.0", features = ["full"] } bytes = "1" http-body-util = "0.1.0" tokio = { version = "1", features = ["macros", "test-util"] } diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs index 2747191..ccd315d 100644 --- a/src/server/conn/auto.rs +++ b/src/server/conn/auto.rs @@ -1,9 +1,11 @@ //! Http1 or Http2 connection. use futures_util::ready; +use hyper::service::HttpService; use std::future::Future; use std::io::{Error as IoError, ErrorKind, Result as IoResult}; use std::marker::PhantomPinned; +use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{error::Error as StdError, marker::Unpin, time::Duration}; @@ -66,7 +68,7 @@ impl Builder { } /// Bind a connection together with a [`Service`]. - pub async fn serve_connection(&self, io: I, service: S) -> Result<()> + pub fn serve_connection(&self, io: I, service: S) -> Connection<'_, I, S, E> where S: Service, Response = Response>, S::Future: 'static, @@ -76,19 +78,23 @@ impl Builder { I: Read + Write + Unpin + 'static, E: Http2ServerConnExec, { - let (version, io) = read_version(io).await?; - match version { - Version::H1 => self.http1.serve_connection(io, service).await?, - Version::H2 => self.http2.serve_connection(io, service).await?, + Connection { + state: ConnState::ReadVersion { + read_version: read_version(io), + builder: self, + service: Some(service), + }, } - - Ok(()) } /// Bind a connection together with a [`Service`], with the ability to /// handle HTTP upgrades. This requires that the IO object implements /// `Send`. - pub async fn serve_connection_with_upgrades(&self, io: I, service: S) -> Result<()> + pub fn serve_connection_with_upgrades( + &self, + io: I, + service: S, + ) -> UpgradeableConnection<'_, I, S, E> where S: Service, Response = Response>, S::Future: 'static, @@ -98,18 +104,13 @@ impl Builder { I: Read + Write + Unpin + Send + 'static, E: Http2ServerConnExec, { - let (version, io) = read_version(io).await?; - match version { - Version::H1 => { - self.http1 - .serve_connection(io, service) - .with_upgrades() - .await? - } - Version::H2 => self.http2.serve_connection(io, service).await?, + UpgradeableConnection { + state: UpgradeableConnState::ReadVersion { + read_version: read_version(io), + builder: self, + service: Some(service), + }, } - - Ok(()) } } #[derive(Copy, Clone)] @@ -117,26 +118,26 @@ enum Version { H1, H2, } -async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind)> + +fn read_version(io: I) -> ReadVersion where - A: Read + Unpin, + I: Read + Unpin, { - use std::mem::MaybeUninit; - - let mut buf = [MaybeUninit::uninit(); 24]; - let (version, buf) = ReadVersion { - reader: &mut reader, - buf: ReadBuf::uninit(&mut buf), + ReadVersion { + io: Some(io), + buf: [MaybeUninit::uninit(); 24], + filled: 0, version: Version::H1, _pin: PhantomPinned, } - .await?; - Ok((version, Rewind::new_buffered(reader, Bytes::from(buf)))) } + pin_project! { - struct ReadVersion<'a, A: ?Sized> { - reader: &'a mut A, - buf: ReadBuf<'a>, + struct ReadVersion { + io: Option, + buf: [MaybeUninit; 24], + // the amount of `buf` thats been filled + filled: usize, version: Version, // Make this future `!Unpin` for compatibility with async trait methods. #[pin] @@ -144,30 +145,261 @@ pin_project! { } } -impl Future for ReadVersion<'_, A> +impl Future for ReadVersion where - A: Read + Unpin + ?Sized, + I: Read + Unpin, { - type Output = IoResult<(Version, Vec)>; + type Output = IoResult<(Version, Rewind)>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll)>> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - while this.buf.filled().len() < H2_PREFACE.len() { - if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] { - return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec()))); - } - // if our buffer is empty, then we need to read some data to continue. - let len = this.buf.filled().len(); - ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf.unfilled()))?; - if this.buf.filled().len() == len { - return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into(); + let mut buf = ReadBuf::uninit(&mut *this.buf); + // SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and + // we're only advancing by that many. + unsafe { + buf.unfilled().advance(*this.filled); + }; + + while buf.filled().len() < H2_PREFACE.len() { + if buf.filled() != &H2_PREFACE[0..buf.filled().len()] { + let io = this.io.take().unwrap(); + let buf = buf.filled().to_vec(); + return Poll::Ready(Ok(( + *this.version, + Rewind::new_buffered(io, Bytes::from(buf)), + ))); + } else { + // if our buffer is empty, then we need to read some data to continue. + let len = buf.filled().len(); + ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?; + *this.filled = buf.filled().len(); + if buf.filled().len() == len { + return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into(); + } } } - if this.buf.filled() == H2_PREFACE { + if buf.filled() == H2_PREFACE { *this.version = Version::H2; } - return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec()))); + let io = this.io.take().unwrap(); + let buf = buf.filled().to_vec(); + Poll::Ready(Ok(( + *this.version, + Rewind::new_buffered(io, Bytes::from(buf)), + ))) + } +} + +pin_project! { + /// Connection future. + pub struct Connection<'a, I, S, E> + where + S: HttpService, + { + #[pin] + state: ConnState<'a, I, S, E>, + } +} + +pin_project! { + #[project = ConnStateProj] + enum ConnState<'a, I, S, E> + where + S: HttpService, + { + ReadVersion { + #[pin] + read_version: ReadVersion, + builder: &'a Builder, + service: Option, + }, + H1 { + #[pin] + conn: hyper::server::conn::http1::Connection, S>, + }, + H2 { + #[pin] + conn: hyper::server::conn::http2::Connection, S, E>, + }, + } +} + +impl Connection<'_, I, S, E> +where + S: HttpService, + S::Error: Into>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into>, + E: Http2ServerConnExec, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still pending. If called after + /// `Connection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + ConnStateProj::ReadVersion { .. } => {} + ConnStateProj::H1 { conn } => conn.graceful_shutdown(), + ConnStateProj::H2 { conn } => conn.graceful_shutdown(), + } + } +} + +impl Future for Connection<'_, I, S, E> +where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + 'static, + E: Http2ServerConnExec, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + ConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + Version::H1 => { + let conn = builder.http1.serve_connection(io, service); + this.state.set(ConnState::H1 { conn }); + } + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(ConnState::H2 { conn }); + } + } + } + ConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + ConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + } + } + } +} + +pin_project! { + /// Connection future. + pub struct UpgradeableConnection<'a, I, S, E> + where + S: HttpService, + { + #[pin] + state: UpgradeableConnState<'a, I, S, E>, + } +} + +pin_project! { + #[project = UpgradeableConnStateProj] + enum UpgradeableConnState<'a, I, S, E> + where + S: HttpService, + { + ReadVersion { + #[pin] + read_version: ReadVersion, + builder: &'a Builder, + service: Option, + }, + H1 { + #[pin] + conn: hyper::server::conn::http1::UpgradeableConnection, S>, + }, + H2 { + #[pin] + conn: hyper::server::conn::http2::Connection, S, E>, + }, + } +} + +impl UpgradeableConnection<'_, I, S, E> +where + S: HttpService, + S::Error: Into>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into>, + E: Http2ServerConnExec, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `UpgradeableConnection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still nothing. pending. If + /// called after `UpgradeableConnection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + UpgradeableConnStateProj::ReadVersion { .. } => {} + UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(), + UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(), + } + } +} + +impl Future for UpgradeableConnection<'_, I, S, E> +where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + Send + 'static, + E: Http2ServerConnExec, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + UpgradeableConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + Version::H1 => { + let conn = builder.http1.serve_connection(io, service).with_upgrades(); + this.state.set(UpgradeableConnState::H1 { conn }); + } + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(UpgradeableConnState::H2 { conn }); + } + } + } + UpgradeableConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + UpgradeableConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + } + } } }