diff --git a/http-body-util/Cargo.toml b/http-body-util/Cargo.toml index e74a31b..88d2d5c 100644 --- a/http-body-util/Cargo.toml +++ b/http-body-util/Cargo.toml @@ -27,9 +27,15 @@ categories = ["web-programming"] [dependencies] bytes = "1" +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2" http-body = { path = "../http-body" } pin-project-lite = "0.2" +tokio = { version = "1", features = ["time"], optional = true } [dev-dependencies] -tokio = { version = "1", features = ["macros", "rt"] } +tokio = { version = "1", features = ["macros", "rt", "test-util"] } + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/http-body-util/src/lib.rs b/http-body-util/src/lib.rs index 2e55b27..04df6c5 100644 --- a/http-body-util/src/lib.rs +++ b/http-body-util/src/lib.rs @@ -4,6 +4,7 @@ unreachable_pub, rustdoc::broken_intra_doc_links )] +#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] #![cfg_attr(test, deny(warnings))] //! Utilities for [`http_body::Body`]. @@ -16,11 +17,19 @@ pub mod combinators; mod empty; mod full; mod limited; +mod stream; + +#[cfg(feature = "tokio")] +mod throttle; use self::combinators::{BoxBody, MapData, MapErr, UnsyncBoxBody}; pub use self::empty::Empty; pub use self::full::Full; pub use self::limited::{LengthLimitError, Limited}; +pub use self::stream::StreamBody; + +#[cfg(feature = "tokio")] +pub use self::throttle::Throttle; /// An extension trait for [`http_body::Body`] adding various combinators and adapters pub trait BodyExt: http_body::Body { diff --git a/http-body-util/src/limited.rs b/http-body-util/src/limited.rs index 3f3ceae..700b0af 100644 --- a/http-body-util/src/limited.rs +++ b/http-body-util/src/limited.rs @@ -114,7 +114,7 @@ impl Error for LengthLimitError {} #[cfg(test)] mod tests { use super::*; - use crate::Full; + use crate::{Full, StreamBody}; use bytes::Bytes; use std::convert::Infallible; @@ -150,40 +150,25 @@ mod tests { assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); } - struct Chunky(&'static [&'static [u8]]); - - impl Body for Chunky { - type Data = &'static [u8]; - type Error = Infallible; - - fn poll_data( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll>> { - let mut this = self; - match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) { - Some((data, new_tail)) => { - this.0 = new_tail; - - Poll::Ready(Some(data)) - } - None => Poll::Ready(None), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(Some(HeaderMap::new()))) - } + fn body_from_iter(into_iter: I) -> impl Body + where + I: IntoIterator, + I::Item: Into + 'static, + I::IntoIter: Send + 'static, + { + let iter = into_iter + .into_iter() + .map(Into::into) + .map(Ok::<_, Infallible>); + + StreamBody::new(futures_util::stream::iter(iter)) } #[tokio::test] async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk( ) { - const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"]; - let inner = Chunky(DATA); + const DATA: [&[u8]; 2] = [b"testing ", b"a string that is too long"]; + let inner = body_from_iter(DATA); let body = &mut Limited::new(inner, 8); let mut hint = SizeHint::new(); @@ -201,8 +186,8 @@ mod tests { #[tokio::test] async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() { - const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"]; - let inner = Chunky(DATA); + const DATA: [&[u8]; 2] = [b"testing a string", b" that is too long"]; + let inner = body_from_iter(DATA); let body = &mut Limited::new(inner, 8); let mut hint = SizeHint::new(); @@ -215,8 +200,8 @@ mod tests { #[tokio::test] async fn read_for_chunked_body_under_limit_is_okay() { - const DATA: &[&[u8]] = &[b"test", b"ing!"]; - let inner = Chunky(DATA); + const DATA: [&[u8]; 2] = [b"test", b"ing!"]; + let inner = body_from_iter(DATA); let body = &mut Limited::new(inner, 8); let mut hint = SizeHint::new(); @@ -236,11 +221,30 @@ mod tests { assert!(matches!(body.data().await, None)); } + struct SomeTrailers; + + impl Body for SomeTrailers { + type Data = Bytes; + type Error = Infallible; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(Some(HeaderMap::new()))) + } + } + #[tokio::test] async fn read_for_trailers_propagates_inner_trailers() { - const DATA: &[&[u8]] = &[b"test", b"ing!"]; - let inner = Chunky(DATA); - let body = &mut Limited::new(inner, 8); + let body = &mut Limited::new(SomeTrailers, 8); let trailers = body.trailers().await.unwrap(); assert_eq!(trailers, Some(HeaderMap::new())) } diff --git a/http-body-util/src/stream.rs b/http-body-util/src/stream.rs new file mode 100644 index 0000000..3f4e617 --- /dev/null +++ b/http-body-util/src/stream.rs @@ -0,0 +1,85 @@ +use bytes::Buf; +use futures_util::stream::Stream; +use http::HeaderMap; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pin_project! { + /// A body created from a `Stream`. + #[derive(Clone, Copy, Debug)] + pub struct StreamBody { + #[pin] + stream: S, + } +} + +impl StreamBody { + /// Create a new `StreamBody`. + pub fn new(stream: S) -> Self { + Self { stream } + } +} + +impl Body for StreamBody +where + S: Stream>, + D: Buf, +{ + type Data = D; + type Error = E; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.project().stream.poll_next(cx) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } +} + +impl Stream for StreamBody { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +#[cfg(test)] +mod tests { + use crate::StreamBody; + use bytes::Bytes; + use http_body::Body; + use std::convert::Infallible; + + #[tokio::test] + async fn body_from_stream() { + let chunks: Vec> = vec![ + Ok(Bytes::from(vec![1])), + Ok(Bytes::from(vec![2])), + Ok(Bytes::from(vec![3])), + ]; + let stream = futures_util::stream::iter(chunks); + let mut body = StreamBody::new(stream); + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [1]); + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [2]); + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [3]); + + assert!(body.data().await.is_none()); + } +} diff --git a/http-body-util/src/throttle.rs b/http-body-util/src/throttle.rs new file mode 100644 index 0000000..2083b9d --- /dev/null +++ b/http-body-util/src/throttle.rs @@ -0,0 +1,163 @@ +use bytes::Buf; +use http::HeaderMap; +use http_body::{Body, SizeHint}; +use pin_project_lite::pin_project; +use std::{ + convert::{TryFrom, TryInto}, + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::{sleep, Instant, Sleep}; + +#[derive(Debug)] +enum State { + Waiting(Pin>, Instant), + Ready(Instant), + Init, +} + +pin_project! { + /// A throttled body. + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] + pub struct Throttle { + #[pin] + inner: B, + state: State, + cursor: f64, + byte_rate: f64, + } +} + +impl Throttle { + /// Create a new `Throttle`. + /// + /// # Panic + /// + /// Will panic if milliseconds in `duration` is larger than `u32::MAX`. + pub fn new(body: B, duration: Duration, bytes: u32) -> Self { + let bytes = f64::from(bytes); + let duration = f64::from(u32::try_from(duration.as_millis()).expect("duration too large")); + + let byte_rate = bytes / duration; + + Self { + inner: body, + state: State::Init, + cursor: 0.0, + byte_rate, + } + } +} + +impl Body for Throttle { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self.project(); + + loop { + match this.state { + State::Waiting(sleep, time) => match sleep.as_mut().poll(cx) { + Poll::Ready(()) => { + let byte_rate = *this.byte_rate; + let mut elapsed = to_f64(time.elapsed().as_millis()); + + if elapsed > 2000.0 { + elapsed = 2000.0; + } + + *this.cursor += elapsed * byte_rate; + *this.state = State::Ready(Instant::now()); + } + Poll::Pending => return Poll::Pending, + }, + State::Ready(time) => match this.inner.as_mut().poll_data(cx) { + Poll::Ready(Some(Ok(data))) => { + let byte_count = to_f64(data.remaining()); + let byte_rate = *this.byte_rate; + + *this.cursor -= byte_count; + + if *this.cursor <= 0.0 { + let wait_millis = this.cursor.abs() / byte_rate; + let duration = Duration::from_millis(wait_millis as u64); + + *this.state = State::Waiting(Box::pin(sleep(duration)), *time); + } + + return Poll::Ready(Some(Ok(data))); + } + poll_result => return poll_result, + }, + State::Init => *this.state = State::Ready(Instant::now()), + } + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.project().inner.poll_trailers(cx) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} + +fn to_f64(n: impl TryInto) -> f64 { + f64::from(n.try_into().unwrap_or(u32::MAX)) +} + +#[cfg(test)] +mod tests { + use crate::{StreamBody, Throttle}; + use bytes::Bytes; + use http_body::Body; + use std::{convert::Infallible, time::Duration}; + use tokio::time::Instant; + + #[tokio::test(start_paused = true)] + async fn per_second_256() { + let start = Instant::now(); + + let chunks: Vec> = vec![ + Ok(Bytes::from(vec![0u8; 128])), + Ok(Bytes::from(vec![0u8; 128])), + Ok(Bytes::from(vec![0u8; 256])), + Ok(Bytes::from(vec![0u8; 128])), + Ok(Bytes::from(vec![0u8; 128])), + ]; + let stream = futures_util::stream::iter(chunks); + let mut body = Throttle::new(StreamBody::new(stream), Duration::from_secs(1), 256); + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [0u8; 128]); + assert!(start.elapsed().is_zero()); // Throttling starts after first chunk. + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [0u8; 128]); + assert_eq!(start.elapsed(), Duration::from_millis(500)); + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [0u8; 256]); + assert_eq!(start.elapsed(), Duration::from_millis(1000)); + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [0u8; 128]); + assert_eq!(start.elapsed(), Duration::from_millis(2000)); + + assert_eq!(body.data().await.unwrap().unwrap().as_ref(), [0u8; 128]); + assert_eq!(start.elapsed(), Duration::from_millis(2500)); + + assert!(body.data().await.is_none()); + } +}