diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7f6bfe468..9f9c4fbdf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -298,6 +298,7 @@ jobs: cargo update -p tokio --precise 1.29.1 cargo update -p tokio-util --precise 0.7.11 cargo update -p idna_adapter --precise 1.1.0 + cargo update -p hashbrown@0.15.2 --precise 0.15.0 - uses: Swatinem/rust-cache@v2 diff --git a/Cargo.toml b/Cargo.toml index 39ff48424..1a0c4abf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Sean McArthur "] readme = "README.md" license = "MIT OR Apache-2.0" edition = "2021" -rust-version = "1.63.0" +rust-version = "1.64.0" autotests = true [package.metadata.docs.rs] @@ -105,6 +105,7 @@ url = "2.4" bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" +tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] } tower-service = "0.3" futures-core = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false } @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } @@ -222,6 +222,11 @@ features = [ wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] } wasm-bindgen-test = "0.3" +[dev-dependencies] +tower = { version = "0.5.2", default-features = false, features = ["limit"] } +num_cpus = "1.0" +libc = "0" + [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] } @@ -253,6 +258,10 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "connect_via_lower_priority_tokio_runtime" +path = "examples/connect_via_lower_priority_tokio_runtime.rs" + [[test]] name = "blocking" path = "tests/blocking.rs" diff --git a/README.md b/README.md index 1f8dbcb1f..b0b0eb813 100644 --- a/README.md +++ b/README.md @@ -53,13 +53,14 @@ On Linux: - OpenSSL with headers. See https://docs.rs/openssl for supported versions and more details. Alternatively you can enable the `native-tls-vendored` - feature to compile a copy of OpenSSL. + feature to compile a copy of OpenSSL. Or, you can use [rustls](https://github.com/rustls/rustls) + via `rustls-tls` or other `rustls-tls-*` features. On Windows and macOS: - Nothing. -Reqwest uses [rust-native-tls](https://github.com/sfackler/rust-native-tls), +By default, Reqwest uses [rust-native-tls](https://github.com/sfackler/rust-native-tls), which will use the operating system TLS framework if available, meaning Windows and macOS. On Linux, it will use the available OpenSSL or fail to build if not found. diff --git a/examples/connect_via_lower_priority_tokio_runtime.rs b/examples/connect_via_lower_priority_tokio_runtime.rs new file mode 100644 index 000000000..33151d4a1 --- /dev/null +++ b/examples/connect_via_lower_priority_tokio_runtime.rs @@ -0,0 +1,264 @@ +#![deny(warnings)] +// This example demonstrates how to delegate the connect calls, which contain TLS handshakes, +// to a secondary tokio runtime of lower OS thread priority using a custom tower layer. +// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures. +// +// This does introduce overhead of additional threads, channels, extra vtables, etc, +// so it is best suited to services with large numbers of incoming connections or that +// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads +// and/or to use the current_thread runtime. +// +// This is using the `tokio` runtime and certain other dependencies: +// +// `tokio = { version = "1", features = ["full"] }` +// `num_cpus = "1.0"` +// `libc = "0"` +// `pin-project-lite = "0.2"` +// `tower = { version = "0.5", default-features = false}` + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + background_threadpool::init_background_runtime(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let client = reqwest::Client::builder() + .connector_layer(background_threadpool::BackgroundProcessorLayer::new()) + .build() + .expect("should be able to build reqwest client"); + + let url = if let Some(url) = std::env::args().nth(1) { + url + } else { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + }; + + eprintln!("Fetching {url:?}..."); + + let res = client.get(url).send().await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{body}"); + + Ok(()) +} + +// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod background_threadpool { + use std::{ + future::Future, + pin::Pin, + sync::OnceLock, + task::{Context, Poll}, + }; + + use futures_util::TryFutureExt; + use pin_project_lite::pin_project; + use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError}; + use tower::{BoxError, Layer, Service}; + + static CPU_HEAVY_THREAD_POOL: OnceLock< + tokio::sync::mpsc::Sender + Send + 'static>>>, + > = OnceLock::new(); + + pub(crate) fn init_background_runtime() { + std::thread::Builder::new() + .name("cpu-heavy-background-threadpool".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("cpu-heavy-background-pool-thread") + .worker_threads(num_cpus::get() as usize) + // ref: https://github.com/tokio-rs/tokio/issues/4941 + // consider uncommenting if seeing heavy task contention + // .disable_lifo_slot() + .on_thread_start(move || { + #[cfg(target_os = "linux")] + unsafe { + // Increase thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + { + *libc::__errno_location() = 0; + if libc::nice(10) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness: {}", error); + } + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + rt.block_on(async { + log::debug!("starting background cpu-heavy work"); + process_cpu_work().await; + }); + }) + .unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e)); + } + + #[cfg(not(target_arch = "wasm32"))] + async fn process_cpu_work() { + // we only use this channel for routing work, it should move pretty quick, it can be small + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + // share the handle to the background channel globally + CPU_HEAVY_THREAD_POOL.set(tx).unwrap(); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(work); + } + } + + // retrieve the sender to the background channel, and send the future over to it for execution + fn send_to_background_runtime(future: impl Future + Send + 'static) { + let tx = CPU_HEAVY_THREAD_POOL.get().expect( + "start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`", + ); + + match tx.try_send(Box::pin(future)) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => { + panic!("background cpu heavy runtime channel is closed") + } + Err(TrySendError::Full(msg)) => { + log::warn!( + "background cpu heavy runtime channel is full, task spawning loop delayed" + ); + let tx = tx.clone(); + Handle::current().spawn(async move { + tx.send(msg) + .await + .expect("background cpu heavy runtime channel is closed") + }); + } + } + } + + // This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing. + // We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to + // spawn a new task per call. + #[derive(Copy, Clone)] + pub struct BackgroundProcessorLayer {} + impl BackgroundProcessorLayer { + pub fn new() -> Self { + Self {} + } + } + impl Layer for BackgroundProcessorLayer { + type Service = BackgroundProcessor; + fn layer(&self, service: S) -> Self::Service { + BackgroundProcessor::new(service) + } + } + + impl std::fmt::Debug for BackgroundProcessorLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("BackgroundProcessorLayer").finish() + } + } + + // This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing. + #[derive(Debug, Clone)] + pub struct BackgroundProcessor { + inner: S, + } + + impl BackgroundProcessor { + pub fn new(inner: S) -> Self { + BackgroundProcessor { inner } + } + } + + impl Service for BackgroundProcessor + where + S: Service, + S::Response: Send + 'static, + S::Error: Into + Send, + S::Future: Send + 'static, + { + type Response = S::Response; + + type Error = BoxError; + + type Future = BackgroundResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + + // wrap our inner service's future with a future that writes to this oneshot channel + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let future = async move { + select!( + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + } + result = response.map_err(|err| Into::::into(err)) => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + ) + }; + // send the wrapped future to the background + send_to_background_runtime(future); + + BackgroundResponseFuture::new(rx) + } + } + + // `BackgroundProcessor` response future + pin_project! { + #[derive(Debug)] + pub struct BackgroundResponseFuture { + #[pin] + rx: tokio::sync::oneshot::Receiver>, + } + } + + impl BackgroundResponseFuture { + pub(crate) fn new(rx: tokio::sync::oneshot::Receiver>) -> Self { + BackgroundResponseFuture { rx } + } + } + + impl Future for BackgroundResponseFuture + where + S: Send + 'static, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // now poll on the receiver end of the oneshot to get the result + match this.rx.poll(cx) { + Poll::Ready(v) => match v { + Ok(v) => Poll::Ready(v.map_err(Into::into)), + Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. +#[cfg(any(target_arch = "wasm32"))] +fn main() {} diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index c2f1257c1..454046dd0 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -148,10 +148,7 @@ impl Body { { use http_body_util::BodyExt; - let boxed = inner - .map_frame(|f| f.map_data(Into::into)) - .map_err(Into::into) - .boxed(); + let boxed = IntoBytesBody { inner }.map_err(Into::into).boxed(); Body { inner: Inner::Streaming(boxed), @@ -461,6 +458,47 @@ where } } +// ===== impl IntoBytesBody ===== + +pin_project! { + struct IntoBytesBody { + #[pin] + inner: B, + } +} + +// We can't use `map_frame()` because that loses the hint data (for good reason). +// But we aren't transforming the data. +impl hyper::body::Body for IntoBytesBody +where + B: hyper::body::Body, + B::Data: Into, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + match futures_core::ready!(self.project().inner.poll_frame(cx)) { + Some(Ok(f)) => Poll::Ready(Some(Ok(f.map_data(Into::into)))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} + #[cfg(test)] mod tests { use http_body::Body as _; @@ -484,8 +522,9 @@ mod tests { assert!(!bytes_body.is_end_stream()); assert_eq!(bytes_body.size_hint().exact(), Some(3)); - let stream_body = Body::wrap(bytes_body); - assert!(!stream_body.is_end_stream()); - assert_eq!(stream_body.size_hint().exact(), None); + // can delegate even when wrapped + let stream_body = Body::wrap(empty_body); + assert!(stream_body.is_end_stream()); + assert_eq!(stream_body.size_hint().exact(), Some(0)); } } diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index f5d07b7b4..71864fb3b 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,27 +1,14 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; +use std::future::Future; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; -use bytes::Bytes; -use http::header::{ - Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, -}; -use http::uri::Scheme; -use http::Uri; -use hyper_util::client::legacy::connect::HttpConnector; -#[cfg(feature = "default-tls")] -use native_tls_crate::TlsConnector; -use pin_project_lite::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::time::Sleep; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; @@ -30,13 +17,16 @@ use super::Body; use crate::async_impl::h3_client::connect::H3Connector; #[cfg(feature = "http3")] use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; -use crate::connect::Connector; +use crate::connect::{ + sealed::{Conn, Unnameable}, + BoxedConnectorLayer, BoxedConnectorService, Connector, ConnectorBuilder, +}; #[cfg(feature = "cookies")] use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; -use crate::error; +use crate::error::{self, BoxError}; use crate::into_url::try_uri; use crate::redirect::{self, remove_sensitive_headers}; #[cfg(feature = "__rustls")] @@ -48,11 +38,25 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature = "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use bytes::Bytes; +use http::header::{ + Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, +}; +use http::uri::Scheme; +use http::Uri; +use hyper_util::client::legacy::connect::HttpConnector; use log::debug; +#[cfg(feature = "default-tls")] +use native_tls_crate::TlsConnector; +use pin_project_lite::pin_project; #[cfg(feature = "http3")] use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +use tokio::time::Sleep; +use tower::util::BoxCloneSyncServiceLayer; +use tower::{Layer, Service}; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; @@ -130,6 +134,7 @@ struct Config { tls_info: bool, #[cfg(feature = "__tls")] tls: TlsBackend, + connector_layers: Vec, http_version_pref: HttpVersionPref, http09_responses: bool, http1_title_case_headers: bool, @@ -185,7 +190,7 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { let mut headers: HeaderMap = HeaderMap::with_capacity(2); headers.insert(ACCEPT, HeaderValue::from_static("*/*")); @@ -233,6 +238,7 @@ impl ClientBuilder { tls_info: false, #[cfg(feature = "__tls")] tls: TlsBackend::default(), + connector_layers: Vec::new(), http_version_pref: HttpVersionPref::All, http09_responses: false, http1_title_case_headers: false, @@ -278,7 +284,9 @@ impl ClientBuilder { }, } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -302,7 +310,7 @@ impl ClientBuilder { #[cfg(feature = "http3")] let mut h3_connector = None; - let mut connector = { + let mut connector_builder = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { headers.get(USER_AGENT).cloned() @@ -445,7 +453,7 @@ impl ClientBuilder { tls.max_protocol_version(Some(protocol)); } - Connector::new_default_tls( + ConnectorBuilder::new_default_tls( http, tls, proxies.clone(), @@ -462,7 +470,7 @@ impl ClientBuilder { )? } #[cfg(feature = "native-tls")] - TlsBackend::BuiltNativeTls(conn) => Connector::from_built_default_tls( + TlsBackend::BuiltNativeTls(conn) => ConnectorBuilder::from_built_default_tls( http, conn, proxies.clone(), @@ -489,7 +497,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, conn, proxies.clone(), @@ -684,7 +692,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, tls, proxies.clone(), @@ -709,7 +717,7 @@ impl ClientBuilder { } #[cfg(not(feature = "__tls"))] - Connector::new( + ConnectorBuilder::new( http, proxies.clone(), config.local_address, @@ -719,8 +727,9 @@ impl ClientBuilder { ) }; - connector.set_timeout(config.connect_timeout); - connector.set_verbose(config.connection_verbose); + connector_builder.set_timeout(config.connect_timeout); + connector_builder.set_verbose(config.connection_verbose); + connector_builder.set_keepalive(config.tcp_keepalive); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -763,7 +772,6 @@ impl ClientBuilder { builder.pool_timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); - connector.set_keepalive(config.tcp_keepalive); if config.http09_responses { builder.http09_responses(true); @@ -801,7 +809,7 @@ impl ClientBuilder { } None => None, }, - hyper: builder.build(connector), + hyper: builder.build(connector_builder.build(config.connector_layers)), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -1953,6 +1961,43 @@ impl ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment.a + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// If configured, the `connect_timeout` will be the outermost layer. + /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// # #[cfg(not(feature = "rustls-tls-no-provider"))] + /// let client = reqwest::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + /// + pub fn connector_layer(mut self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + let layer = BoxCloneSyncServiceLayer::new(layer); + + self.config.connector_layers.push(layer); + + self + } } type HyperClient = hyper_util::client::legacy::Client; diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index d742e6d35..96a27ac45 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] +use futures_util::stream::Fuse; + #[cfg(feature = "gzip")] use async_compression::tokio::bufread::GzipDecoder; @@ -108,19 +116,19 @@ enum Inner { /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] - Gzip(Pin, BytesCodec>>>), + Gzip(Pin, BytesCodec>>>>), /// A `Brotli` decoder will uncompress the brotlied response content before returning it. #[cfg(feature = "brotli")] - Brotli(Pin, BytesCodec>>>), + Brotli(Pin, BytesCodec>>>>), /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. #[cfg(feature = "zstd")] - Zstd(Pin, BytesCodec>>>), + Zstd(Pin, BytesCodec>>>>), /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] - Deflate(Pin, BytesCodec>>>), + Deflate(Pin, BytesCodec>>>>), /// A decoder that doesn't have a value yet. #[cfg(any( @@ -365,34 +373,74 @@ impl HttpBody for Decoder { } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after gzip stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after brotli stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "zstd")] Inner::Zstd(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after zstd stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after deflate stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } } @@ -456,25 +504,37 @@ impl Future for Pending { match self.1 { #[cfg(feature = "brotli")] - DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new( - BrotliDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin( + FramedRead::new( + BrotliDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "zstd")] - DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( - ZstdDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin( + FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "gzip")] - DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( - GzipDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin( + FramedRead::new( + GzipDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "deflate")] - DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new( - ZlibDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin( + FramedRead::new( + ZlibDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), } } } diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 4d1d07d7d..f3e8f6e1c 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -12,11 +12,16 @@ use std::time::Duration; use http::header::HeaderValue; use log::{error, trace}; use tokio::sync::{mpsc, oneshot}; +use tower::Layer; +use tower::Service; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::wait; +use crate::connect::sealed::{Conn, Unnameable}; +use crate::connect::BoxedConnectorService; use crate::dns::Resolve; +use crate::error::BoxError; #[cfg(feature = "__tls")] use crate::tls; #[cfg(feature = "__rustls")] @@ -84,13 +89,15 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { ClientBuilder { inner: async_impl::ClientBuilder::new(), timeout: Timeout::default(), } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -968,6 +975,35 @@ impl ClientBuilder { self.with_inner(|inner| inner.dns_resolver(resolver)) } + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// let client = reqwest::blocking::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + pub fn connector_layer(self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + self.with_inner(|inner| inner.connector_layer(layer)) + } + // private fn with_inner(mut self, func: F) -> ClientBuilder diff --git a/src/connect.rs b/src/connect.rs index ff86ba3c9..dfaf028a9 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -8,9 +8,11 @@ use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper_util::rt::TokioIo; #[cfg(feature = "default-tls")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; +use pin_project_lite::pin_project; +use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer}; +use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder}; use tower_service::Service; -use pin_project_lite::pin_project; use std::future::Future; use std::io::{self, IoSlice}; use std::net::IpAddr; @@ -24,13 +26,47 @@ use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; -use crate::error::BoxError; +use crate::error::{cast_to_internal_error, BoxError}; use crate::proxy::{Proxy, ProxyScheme}; +use sealed::{Conn, Unnameable}; pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] -pub(crate) struct Connector { +pub(crate) enum Connector { + // base service, with or without an embedded timeout + Simple(ConnectorService), + // at least one custom layer along with maybe an outer timeout layer + // from `builder.connect_timeout()` + WithLayers(BoxCloneSyncService), +} + +impl Service for Connector { + type Response = Conn; + type Error = BoxError; + type Future = Connecting; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Connector::Simple(service) => service.poll_ready(cx), + Connector::WithLayers(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + match self { + Connector::Simple(service) => service.call(dst), + Connector::WithLayers(service) => service.call(Unnameable(dst)), + } + } +} + +pub(crate) type BoxedConnectorService = BoxCloneSyncService; + +pub(crate) type BoxedConnectorLayer = + BoxCloneSyncServiceLayer; + +pub(crate) struct ConnectorBuilder { inner: Inner, proxies: Arc>, verbose: verbose::Wrapper, @@ -43,21 +79,70 @@ pub(crate) struct Connector { user_agent: Option, } -#[derive(Clone)] -enum Inner { - #[cfg(not(feature = "__tls"))] - Http(HttpConnector), - #[cfg(feature = "default-tls")] - DefaultTls(HttpConnector, TlsConnector), - #[cfg(feature = "__rustls")] - RustlsTls { - http: HttpConnector, - tls: Arc, - tls_proxy: Arc, - }, -} +impl ConnectorBuilder { + pub(crate) fn build(self, layers: Vec) -> Connector +where { + // construct the inner tower service + let mut base_service = ConnectorService { + inner: self.inner, + proxies: self.proxies, + verbose: self.verbose, + #[cfg(feature = "__tls")] + nodelay: self.nodelay, + #[cfg(feature = "__tls")] + tls_info: self.tls_info, + #[cfg(feature = "__tls")] + user_agent: self.user_agent, + simple_timeout: None, + }; + + if layers.is_empty() { + // we have no user-provided layers, only use concrete types + base_service.simple_timeout = self.timeout; + return Connector::Simple(base_service); + } + + // otherwise we have user provided layers + // so we need type erasure all the way through + // as well as mapping the unnameable type of the layers back to Uri for the inner service + let unnameable_service = ServiceBuilder::new() + .layer(MapRequestLayer::new(|request: Unnameable| request.0)) + .service(base_service); + let mut service = BoxCloneSyncService::new(unnameable_service); + + for layer in layers { + service = ServiceBuilder::new().layer(layer).service(service); + } + + // now we handle the concrete stuff - any `connect_timeout`, + // plus a final map_err layer we can use to cast default tower layer + // errors to internal errors + match self.timeout { + Some(timeout) => { + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(timeout)) + .service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + + Connector::WithLayers(service) + } + None => { + // no timeout, but still map err + // no named timeout layer but we still map errors since + // we might have user-provided timeout layer + let service = ServiceBuilder::new().service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + Connector::WithLayers(service) + } + } + } -impl Connector { #[cfg(not(feature = "__tls"))] pub(crate) fn new( mut http: HttpConnector, @@ -66,7 +151,7 @@ impl Connector { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] interface: Option<&str>, nodelay: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -77,10 +162,10 @@ impl Connector { } http.set_nodelay(nodelay); - Connector { + ConnectorBuilder { inner: Inner::Http(http), - verbose: verbose::OFF, proxies, + verbose: verbose::OFF, timeout: None, } } @@ -96,7 +181,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> crate::Result + ) -> crate::Result where T: Into>, { @@ -125,7 +210,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -137,14 +222,14 @@ impl Connector { http.set_nodelay(nodelay); http.enforce_http(false); - Connector { + ConnectorBuilder { inner: Inner::DefaultTls(http, tls), proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -159,7 +244,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -180,7 +265,7 @@ impl Connector { (Arc::new(tls), Arc::new(tls_proxy)) }; - Connector { + ConnectorBuilder { inner: Inner::RustlsTls { http, tls, @@ -188,10 +273,10 @@ impl Connector { }, proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -203,6 +288,52 @@ impl Connector { self.verbose.0 = enabled; } + pub(crate) fn set_keepalive(&mut self, dur: Option) { + match &mut self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), + #[cfg(feature = "__rustls")] + Inner::RustlsTls { http, .. } => http.set_keepalive(dur), + #[cfg(not(feature = "__tls"))] + Inner::Http(http) => http.set_keepalive(dur), + } + } +} + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub(crate) struct ConnectorService { + inner: Inner, + proxies: Arc>, + verbose: verbose::Wrapper, + /// When there is a single timeout layer and no other layers, + /// we embed it directly inside our base Service::call(). + /// This lets us avoid an extra `Box::pin` indirection layer + /// since `tokio::time::Timeout` is `Unpin` + simple_timeout: Option, + #[cfg(feature = "__tls")] + nodelay: bool, + #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] + user_agent: Option, +} + +#[derive(Clone)] +enum Inner { + #[cfg(not(feature = "__tls"))] + Http(HttpConnector), + #[cfg(feature = "default-tls")] + DefaultTls(HttpConnector, TlsConnector), + #[cfg(feature = "__rustls")] + RustlsTls { + http: HttpConnector, + tls: Arc, + tls_proxy: Arc, + }, +} + +impl ConnectorService { #[cfg(feature = "socks")] async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result { let dns = match proxy { @@ -449,17 +580,6 @@ impl Connector { self.connect_with_maybe_proxy(proxy_dst, true).await } - - pub fn set_keepalive(&mut self, dur: Option) { - match &mut self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), - #[cfg(feature = "__rustls")] - Inner::RustlsTls { http, .. } => http.set_keepalive(dur), - #[cfg(not(feature = "__tls"))] - Inner::Http(http) => http.set_keepalive(dur), - } - } } fn into_uri(scheme: Scheme, host: Authority) -> Uri { @@ -487,7 +607,7 @@ where } } -impl Service for Connector { +impl Service for ConnectorService { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -498,7 +618,7 @@ impl Service for Connector { fn call(&mut self, dst: Uri) -> Self::Future { log::debug!("starting new connection: {dst:?}"); - let timeout = self.timeout; + let timeout = self.simple_timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( @@ -633,80 +753,87 @@ impl AsyncConnWithInfo for T {} type BoxConn = Box; -pin_project! { - /// Note: the `is_proxy` member means *is plain text HTTP proxy*. - /// This tells hyper whether the URI should be written in - /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or - /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. - pub(crate) struct Conn { - #[pin] - inner: BoxConn, - is_proxy: bool, - // Only needed for __tls, but #[cfg()] on fields breaks pin_project! - tls_info: bool, +pub(crate) mod sealed { + use super::*; + #[derive(Debug)] + pub struct Unnameable(pub(super) Uri); + + pin_project! { + /// Note: the `is_proxy` member means *is plain text HTTP proxy*. + /// This tells hyper whether the URI should be written in + /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or + /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. + #[allow(missing_debug_implementations)] + pub struct Conn { + #[pin] + pub(super)inner: BoxConn, + pub(super) is_proxy: bool, + // Only needed for __tls, but #[cfg()] on fields breaks pin_project! + pub(super) tls_info: bool, + } } -} -impl Connection for Conn { - fn connected(&self) -> Connected { - let connected = self.inner.connected().proxy(self.is_proxy); - #[cfg(feature = "__tls")] - if self.tls_info { - if let Some(tls_info) = self.inner.tls_info() { - connected.extra(tls_info) + impl Connection for Conn { + fn connected(&self) -> Connected { + let connected = self.inner.connected().proxy(self.is_proxy); + #[cfg(feature = "__tls")] + if self.tls_info { + if let Some(tls_info) = self.inner.tls_info() { + connected.extra(tls_info) + } else { + connected + } } else { connected } - } else { + #[cfg(not(feature = "__tls"))] connected } - #[cfg(not(feature = "__tls"))] - connected } -} -impl Read for Conn { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: ReadBufCursor<'_>, - ) -> Poll> { - let this = self.project(); - Read::poll_read(this.inner, cx, buf) + impl Read for Conn { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: ReadBufCursor<'_>, + ) -> Poll> { + let this = self.project(); + Read::poll_read(this.inner, cx, buf) + } } -} -impl Write for Conn { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - let this = self.project(); - Write::poll_write(this.inner, cx, buf) - } + impl Write for Conn { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + Write::poll_write(this.inner, cx, buf) + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - let this = self.project(); - Write::poll_write_vectored(this.inner, cx, bufs) - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + Write::poll_write_vectored(this.inner, cx, bufs) + } - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_flush(this.inner, cx) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_flush(this.inner, cx) + } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_shutdown(this.inner, cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_shutdown(this.inner, cx) + } } } diff --git a/src/error.rs b/src/error.rs index ca7413fd6..6a9f07e51 100644 --- a/src/error.rs +++ b/src/error.rs @@ -165,6 +165,18 @@ impl Error { } } +/// Converts from external types to reqwest's +/// internal equivalents. +/// +/// Currently only is used for `tower::timeout::error::Elapsed`. +pub(crate) fn cast_to_internal_error(error: BoxError) -> BoxError { + if error.is::() { + Box::new(crate::error::TimedOut) as BoxError + } else { + error + } +} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("reqwest::Error"); diff --git a/tests/brotli.rs b/tests/brotli.rs index 5c2b01849..ba116ed92 100644 --- a/tests/brotli.rs +++ b/tests/brotli.rs @@ -1,6 +1,7 @@ mod support; use std::io::Read; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn brotli_response() { @@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: br\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn brotli_compress(input: &[u8]) -> Vec { + let mut encoder = brotli_crate::CompressorReader::new(input, 4096, 5, 20); + let mut brotlied_content = Vec::new(); + encoder.read_to_end(&mut brotlied_content).unwrap(); + brotlied_content +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", brotlied_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &brotlied_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/client.rs b/tests/client.rs index 51fb9dfa0..f99418322 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -197,7 +197,7 @@ async fn body_pipe_response() { http::Response::new("pipe me".into()) } else { assert_eq!(req.uri(), "/pipe"); - assert_eq!(req.headers()["transfer-encoding"], "chunked"); + assert_eq!(req.headers()["content-length"], "7"); let full: Vec = req .into_body() diff --git a/tests/connector_layers.rs b/tests/connector_layers.rs new file mode 100644 index 000000000..1be18aeb8 --- /dev/null +++ b/tests/connector_layers.rs @@ -0,0 +1,374 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::time::Duration; + +use futures_util::future::join_all; +use tower::layer::util::Identity; +use tower::limit::ConcurrencyLimitLayer; +use tower::timeout::TimeoutLayer; + +use support::{delay_layer::DelayLayer, server}; + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer_with_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_never_returning() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_slow() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_under_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(300))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(500))) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_over_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connect_timeout(Duration::from_millis(50)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + let timed_out = all_res + .into_iter() + .any(|res| res.is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + for res in all_res.into_iter() { + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(feature = "blocking")] +#[test] +fn non_op_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(Identity::new()) + .build() + .unwrap(); + + let res = client.get(url).send(); + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn timeout_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send(); + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + let timed_out = join_handles + .into_iter() + .any(|handle| handle.join().unwrap().is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + for handle in join_handles { + let res = handle.join().unwrap(); + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn no_generic_bounds_required_for_client_new() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::new(); + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn no_generic_bounds_required_for_client_new_blocking() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::new(); + let res = client.get(url).send(); + + assert!(res.is_ok()); +} diff --git a/tests/deflate.rs b/tests/deflate.rs index ec27ba180..55331afc5 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -1,6 +1,7 @@ mod support; use std::io::Write; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn deflate_response() { @@ -148,3 +149,214 @@ async fn deflate_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: deflate\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn deflate_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::zlib::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to deflate encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", deflated_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &deflated_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/gzip.rs b/tests/gzip.rs index 57189e0ac..74ead8783 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -2,6 +2,8 @@ mod support; use support::server; use std::io::Write; +use tokio::io::AsyncWriteExt; +use tokio::time::Duration; #[tokio::test] async fn gzip_response() { @@ -149,3 +151,214 @@ async fn gzip_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: gzip\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn gzip_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::gzip::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to gzip encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", gzipped_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &gzipped_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/support/delay_layer.rs b/tests/support/delay_layer.rs new file mode 100644 index 000000000..b8eec42a1 --- /dev/null +++ b/tests/support/delay_layer.rs @@ -0,0 +1,119 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; +use tower::{BoxError, Layer, Service}; + +/// This tower layer injects an arbitrary delay before calling downstream layers. +#[derive(Clone)] +pub struct DelayLayer { + delay: Duration, +} + +impl DelayLayer { + pub const fn new(delay: Duration) -> Self { + DelayLayer { delay } + } +} + +impl Layer for DelayLayer { + type Service = Delay; + fn layer(&self, service: S) -> Self::Service { + Delay::new(service, self.delay) + } +} + +impl std::fmt::Debug for DelayLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DelayLayer") + .field("delay", &self.delay) + .finish() + } +} + +/// This tower service injects an arbitrary delay before calling downstream layers. +#[derive(Debug, Clone)] +pub struct Delay { + inner: S, + delay: Duration, +} +impl Delay { + pub fn new(inner: S, delay: Duration) -> Self { + Delay { inner, delay } + } +} + +impl Service for Delay +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + let sleep = tokio::time::sleep(self.delay); + + ResponseFuture::new(response, sleep) + } +} + +// `Delay` response future +pin_project! { + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + response: S, + #[pin] + sleep: Sleep, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: S, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // First poll the sleep until complete + match this.sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + + // Then poll the inner future + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796956d8..9d4ce7b9b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_layer; pub mod delay_server; pub mod server; diff --git a/tests/support/server.rs b/tests/support/server.rs index 29835ead1..79ebd2d8f 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -6,6 +6,8 @@ use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; use tokio::runtime; use tokio::sync::oneshot; @@ -240,3 +242,104 @@ where .join() .unwrap() } + +pub fn low_level_with_response(do_response: F) -> Server +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c> + + Clone + + Send + + 'static, +{ + // Spawn new runtime in thread to prevent reactor execution context conflict + let test_name = thread::current().name().unwrap_or("").to_string(); + thread::spawn(move || { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("new rt"); + let listener = rt.block_on(async move { + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap() + }); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (panic_tx, panic_rx) = std_mpsc::channel(); + let (events_tx, events_rx) = std_mpsc::channel(); + let tname = format!("test({})-support-server", test_name,); + thread::Builder::new() + .name(tname) + .spawn(move || { + rt.block_on(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + } + accepted = listener.accept() => { + let (io, _) = accepted.expect("accepted"); + let do_response = do_response.clone(); + let events_tx = events_tx.clone(); + tokio::spawn(async move { + low_level_server_client(io, do_response).await; + let _ = events_tx.send(Event::ConnectionClosed); + }); + } + } + } + let _ = panic_tx.send(()); + }); + }) + .expect("thread spawn"); + Server { + addr, + panic_rx, + events_rx, + shutdown_tx: Some(shutdown_tx), + } + }) + .join() + .unwrap() +} + +async fn low_level_server_client(mut client_socket: TcpStream, do_response: F) +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c>, +{ + loop { + let request = low_level_read_http_request(&mut client_socket) + .await + .expect("read_http_request failed"); + if request.is_empty() { + // connection closed by client + break; + } + + Box::into_pin(do_response(&request, &mut client_socket)).await; + } +} + +async fn low_level_read_http_request( + client_socket: &mut TcpStream, +) -> core::result::Result, std::io::Error> { + let mut buf = Vec::new(); + + // Read until the delimiter "\r\n\r\n" is found + loop { + let mut temp_buffer = [0; 1024]; + let n = client_socket.read(&mut temp_buffer).await?; + + if n == 0 { + break; + } + + buf.extend_from_slice(&temp_buffer[..n]); + + if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") { + return Ok(buf.drain(..pos + 4).collect()); + } + } + + Ok(buf) +} diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 79a6fbb4d..71dc0ce66 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -337,6 +337,24 @@ fn timeout_blocking_request() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[cfg(feature = "blocking")] +#[test] +fn connect_timeout_blocking_request() { + let _ = env_logger::try_init(); + + let client = reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let err = client.get(url).send().unwrap_err(); + + assert!(err.is_timeout()); +} + #[cfg(feature = "blocking")] #[cfg(feature = "stream")] #[test] diff --git a/tests/zstd.rs b/tests/zstd.rs index d1886ee49..ed3914e79 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -1,5 +1,6 @@ mod support; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn zstd_response() { @@ -142,3 +143,209 @@ async fn zstd_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: zstd\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn zstd_compress(input: &[u8]) -> Vec { + zstd_crate::encode_all(input, 3).unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", zstded_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &zstded_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +}