From 6104a314be4245d35338f6fb830b3d178cd03d0c Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 1 Dec 2023 09:15:36 -0500 Subject: [PATCH] backport Collect --- src/collect.rs | 222 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 11 +++ 2 files changed, 233 insertions(+) create mode 100644 src/collect.rs diff --git a/src/collect.rs b/src/collect.rs new file mode 100644 index 0000000..b065fff --- /dev/null +++ b/src/collect.rs @@ -0,0 +1,222 @@ +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use super::Body; + +use bytes::{Buf, Bytes}; +use http::HeaderMap; +use pin_project_lite::pin_project; + +pin_project! { + /// Future that resolves into a [`Collected`]. + pub struct Collect + where + T: Body, + { + #[pin] + body: T, + collected: Option>, + is_data_done: bool, + } +} + +impl Collect { + pub(crate) fn new(body: T) -> Self { + Self { + body, + collected: Some(Collected::default()), + is_data_done: false, + } + } +} + +impl Future for Collect { + type Output = Result, T::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut me = self.project(); + + loop { + if !*me.is_data_done { + match me.body.as_mut().poll_data(cx) { + Poll::Ready(Some(Ok(data))) => { + me.collected.as_mut().unwrap().push_data(data); + } + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Err(err)); + } + Poll::Ready(None) => { + *me.is_data_done = true; + } + Poll::Pending => return Poll::Pending, + } + } else { + match me.body.as_mut().poll_trailers(cx) { + Poll::Ready(Ok(Some(trailers))) => { + me.collected.as_mut().unwrap().push_trailers(trailers); + break; + } + Poll::Ready(Err(err)) => { + return Poll::Ready(Err(err)); + } + Poll::Ready(Ok(None)) => break, + Poll::Pending => return Poll::Pending, + } + } + } + + Poll::Ready(Ok(me.collected.take().expect("polled after complete"))) + } +} + +/// A collected body produced by [`Body::collect`] which collects all the DATA frames +/// and trailers. +#[derive(Debug)] +pub struct Collected { + bufs: BufList, + trailers: Option, +} + +impl Collected { + /// If there is a trailers frame buffered, returns a reference to it. + /// + /// Returns `None` if the body contained no trailers. + pub fn trailers(&self) -> Option<&HeaderMap> { + self.trailers.as_ref() + } + + /// Aggregate this buffered into a [`Buf`]. + pub fn aggregate(self) -> impl Buf { + self.bufs + } + + /// Convert this body into a [`Bytes`]. + pub fn to_bytes(mut self) -> Bytes { + self.bufs.copy_to_bytes(self.bufs.remaining()) + } + + fn push_data(&mut self, data: B) { + // Only push this frame if it has some data in it, to avoid crashing on + // `BufList::push`. + if data.has_remaining() { + self.bufs.push(data); + } + } + + fn push_trailers(&mut self, trailers: HeaderMap) { + if let Some(current) = &mut self.trailers { + current.extend(trailers); + } else { + self.trailers = Some(trailers); + } + } +} + +impl Default for Collected { + fn default() -> Self { + Self { + bufs: BufList::default(), + trailers: None, + } + } +} + +impl Unpin for Collected {} + +#[derive(Debug)] +struct BufList { + bufs: VecDeque, +} + +impl BufList { + #[inline] + pub(crate) fn push(&mut self, buf: T) { + debug_assert!(buf.has_remaining()); + self.bufs.push_back(buf); + } + + /* + #[inline] + pub(crate) fn pop(&mut self) -> Option { + self.bufs.pop_front() + } + */ +} + +impl Buf for BufList { + #[inline] + fn remaining(&self) -> usize { + self.bufs.iter().map(|buf| buf.remaining()).sum() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or_default() + } + + #[inline] + fn advance(&mut self, mut cnt: usize) { + while cnt > 0 { + { + let front = &mut self.bufs[0]; + let rem = front.remaining(); + if rem > cnt { + front.advance(cnt); + return; + } else { + front.advance(rem); + cnt -= rem; + } + } + self.bufs.pop_front(); + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [std::io::IoSlice<'t>]) -> usize { + if dst.is_empty() { + return 0; + } + let mut vecs = 0; + for buf in &self.bufs { + vecs += buf.chunks_vectored(&mut dst[vecs..]); + if vecs == dst.len() { + break; + } + } + vecs + } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + use bytes::{BufMut, BytesMut}; + // Our inner buffer may have an optimized version of copy_to_bytes, and if the whole + // request can be fulfilled by the front buffer, we can take advantage. + match self.bufs.front_mut() { + Some(front) if front.remaining() == len => { + let b = front.copy_to_bytes(len); + self.bufs.pop_front(); + b + } + Some(front) if front.remaining() > len => front.copy_to_bytes(len), + _ => { + assert!(len <= self.remaining(), "`len` greater than remaining"); + let mut bm = BytesMut::with_capacity(len); + bm.put(self.take(len)); + bm.freeze() + } + } + } +} + +impl Default for BufList { + fn default() -> Self { + BufList { + bufs: VecDeque::new(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 84efd91..92064c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ //! //! [`Body`]: trait.Body.html +mod collect; mod empty; mod full; mod limited; @@ -21,6 +22,7 @@ mod size_hint; pub mod combinators; +pub use self::collect::Collected; pub use self::empty::Empty; pub use self::full::Full; pub use self::limited::{LengthLimitError, Limited}; @@ -118,6 +120,15 @@ pub trait Body { MapErr::new(self, f) } + /// Turn this body into [`Collected`] body which will collect all the DATA frames + /// and trailers. + fn collect(self) -> crate::collect::Collect + where + Self: Sized, + { + collect::Collect::new(self) + } + /// Turn this body into a boxed trait object. fn boxed(self) -> BoxBody where