From 15ad1ee45cc07188329383c36bff0d8067029570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Tue, 17 Oct 2023 16:51:42 +0200 Subject: [PATCH] Auto-buffer HTTP connections when TLS is set up --- src/client.rs | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/src/client.rs b/src/client.rs index f8c203e..5ba5ef7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -119,6 +119,15 @@ where #[cfg(not(feature = "embedded-tls"))] Err(Error::InvalidUrl(nourl::Error::UnsupportedScheme)) } else { + #[cfg(feature = "embedded-tls")] + match self.tls.as_mut() { + Some(tls) => Ok(HttpConnection::PlainBuffered(BufferedWrite::new( + buffered_io_adapter::ConnErrorAdapter(conn), + tls.write_buffer, + ))), + None => Ok(HttpConnection::Plain(conn)), + } + #[cfg(not(feature = "embedded-tls"))] Ok(HttpConnection::Plain(conn)) } } @@ -155,15 +164,17 @@ where /// Represents a HTTP connection that may be encrypted or unencrypted. #[allow(clippy::large_enum_variant)] -pub enum HttpConnection<'m, T> +pub enum HttpConnection<'m, C> where - T: Read + Write, + C: Read + Write, { - Plain(T), + Plain(C), + #[cfg(feature = "embedded-tls")] + PlainBuffered(BufferedWrite<'m, buffered_io_adapter::ConnErrorAdapter>), #[cfg(feature = "embedded-tls")] - Tls(embedded_tls::TlsConnection<'m, T, embedded_tls::Aes128GcmSha256>), + Tls(embedded_tls::TlsConnection<'m, C, embedded_tls::Aes128GcmSha256>), #[cfg(not(feature = "embedded-tls"))] - Tls(&'m mut T), // Variant is never actually created, but we need it to avoid "unused lifetime" warning + Tls((&'m mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning } impl<'conn, T> HttpConnection<'conn, T> @@ -200,7 +211,12 @@ where async fn read(&mut self, buf: &mut [u8]) -> Result { match self { Self::Plain(conn) => conn.read(buf).await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] + Self::PlainBuffered(conn) => conn.read(buf).await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.read(buf).await.map_err(|e| e.kind()), + #[cfg(not(feature = "embedded-tls"))] + _ => unreachable!(), } } } @@ -212,14 +228,24 @@ where async fn write(&mut self, buf: &[u8]) -> Result { match self { Self::Plain(conn) => conn.write(buf).await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] + Self::PlainBuffered(conn) => conn.write(buf).await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.write(buf).await.map_err(|e| e.kind()), + #[cfg(not(feature = "embedded-tls"))] + _ => unreachable!(), } } async fn flush(&mut self) -> Result<(), Self::Error> { match self { Self::Plain(conn) => conn.flush().await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] + Self::PlainBuffered(conn) => conn.flush().await.map_err(|e| e.kind()), + #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.flush().await.map_err(|e| e.kind()), + #[cfg(not(feature = "embedded-tls"))] + _ => unreachable!(), } } } @@ -241,9 +267,10 @@ where C: Read + Write, B: RequestBody, { - /// Turn the request into a buffered request + /// Turn the request into a buffered request. /// - /// This is most likely only relevant for non-tls endpoints, as `embedded-tls` buffers internally. + /// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse + /// its buffer for non-TLS connections. pub fn into_buffered<'buf>( self, tx_buf: &'buf mut [u8], @@ -328,7 +355,8 @@ where { /// Turn the resource into a buffered resource /// - /// This is most likely only relevant for non-tls endpoints, as `embedded-tls` buffers internally. + /// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse + /// its buffer for non-TLS connections. pub fn into_buffered<'buf>( self, tx_buf: &'buf mut [u8],