From b29fc463b9f1b5729ef77f09c31f7a78111f7faa Mon Sep 17 00:00:00 2001 From: dswij Date: Sun, 17 Mar 2024 20:47:53 +0800 Subject: [PATCH] feat: add `{http1,http2}_only` for auto conn --- src/server/conn/auto.rs | 147 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 143 insertions(+), 4 deletions(-) diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs index 7e74d16..148ed23 100644 --- a/src/server/conn/auto.rs +++ b/src/server/conn/auto.rs @@ -58,6 +58,8 @@ pub struct Builder { http1: http1::Builder, #[cfg(feature = "http2")] http2: http2::Builder, + #[cfg(any(feature = "http1", feature = "http2"))] + version: Option, #[cfg(not(feature = "http2"))] _executor: E, } @@ -84,6 +86,8 @@ impl Builder { http1: http1::Builder::new(), #[cfg(feature = "http2")] http2: http2::Builder::new(executor), + #[cfg(any(feature = "http1", feature = "http2"))] + version: None, #[cfg(not(feature = "http2"))] _executor: executor, } @@ -101,6 +105,26 @@ impl Builder { Http2Builder { inner: self } } + /// Only accepts HTTP/2 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + #[cfg(feature = "http2")] + pub fn http2_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H2); + self + } + + /// Only accepts HTTP/1 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + #[cfg(feature = "http1")] + pub fn http1_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H1); + self + } + /// Bind a connection together with a [`Service`]. pub fn serve_connection(&self, io: I, service: S) -> Connection<'_, I, S, E> where @@ -112,13 +136,28 @@ impl Builder { I: Read + Write + Unpin + 'static, E: HttpServerConnExec, { - Connection { - state: ConnState::ReadVersion { + let state = match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http1.serve_connection(io, service); + ConnState::H1 { conn } + } + #[cfg(feature = "http2")] + Some(Version::H2) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http2.serve_connection(io, service); + ConnState::H2 { conn } + } + #[cfg(any(feature = "http1", feature = "http2"))] + _ => ConnState::ReadVersion { read_version: read_version(io), builder: self, service: Some(service), }, - } + }; + + Connection { state } } /// Bind a connection together with a [`Service`], with the ability to @@ -148,7 +187,7 @@ impl Builder { } } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] enum Version { H1, H2, @@ -894,6 +933,62 @@ mod tests { assert_eq!(body, BODY); } + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only() { + let addr = start_server_h2_only().await; + let mut sender = connect_h2(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only_fail_if_client_is_http1() { + let addr = start_server_h2_only().await; + let mut sender = connect_h1(addr).await; + + let _ = sender + .send_request(Request::new(Empty::::new())) + .await + .expect_err("should fail"); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only() { + let addr = start_server_h1_only().await; + let mut sender = connect_h1(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only_fail_if_client_is_http2() { + let addr = start_server_h1_only().await; + let mut sender = connect_h2(addr).await; + + let _ = sender + .send_request(Request::new(Empty::::new())) + .await + .expect_err("should fail"); + } + #[cfg(not(miri))] #[tokio::test] async fn graceful_shutdown() { @@ -980,6 +1075,50 @@ mod tests { local_addr } + async fn start_server_h2_only() -> SocketAddr { + let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); + let listener = TcpListener::bind(addr).await.unwrap(); + + let local_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioIo::new(stream); + tokio::task::spawn(async move { + let _ = auto::Builder::new(TokioExecutor::new()) + .http2_only() + .serve_connection(stream, service_fn(hello)) + .await; + }); + } + }); + + local_addr + } + + async fn start_server_h1_only() -> SocketAddr { + let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); + let listener = TcpListener::bind(addr).await.unwrap(); + + let local_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioIo::new(stream); + tokio::task::spawn(async move { + let _ = auto::Builder::new(TokioExecutor::new()) + .http1_only() + .serve_connection(stream, service_fn(hello)) + .await; + }); + } + }); + + local_addr + } + async fn hello(_req: Request) -> Result>, Infallible> { Ok(Response::new(Full::new(Bytes::from(BODY)))) }