diff --git a/http-body-util/src/combinators/with_trailers.rs b/http-body-util/src/combinators/with_trailers.rs index 9a9e525..fbaa1da 100644 --- a/http-body-util/src/combinators/with_trailers.rs +++ b/http-body-util/src/combinators/with_trailers.rs @@ -41,6 +41,7 @@ pin_project! { PollTrailers { #[pin] trailers: F, + prev_trailers: Option, }, Trailers { trailers: Option, @@ -65,17 +66,43 @@ where let new_state: State<_, _> = match this.state.as_mut().project() { StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) { - Some(frame) => { - return Poll::Ready(Some(Ok(frame))); - } + Some(frame) => match frame.into_trailers() { + Ok(prev_trailers) => { + let trailers = trailers.take().unwrap(); + State::PollTrailers { + trailers, + prev_trailers: Some(prev_trailers), + } + } + Err(frame) => { + return Poll::Ready(Some(Ok(frame))); + } + }, None => { let trailers = trailers.take().unwrap(); - State::PollTrailers { trailers } + State::PollTrailers { + trailers, + prev_trailers: None, + } } }, - StateProj::PollTrailers { trailers } => { + StateProj::PollTrailers { + trailers, + prev_trailers, + } => { let trailers = ready!(trailers.poll(cx)?); - State::Trailers { trailers } + match (trailers, prev_trailers.take()) { + (None, None) => return Poll::Ready(None), + (None, Some(trailers)) | (Some(trailers), None) => State::Trailers { + trailers: Some(trailers), + }, + (Some(new_trailers), Some(mut prev_trailers)) => { + prev_trailers.extend(new_trailers); + State::Trailers { + trailers: Some(prev_trailers), + } + } + } } StateProj::Trailers { trailers } => { return Poll::Ready(trailers.take().map(Frame::trailers).map(Ok)); @@ -110,7 +137,7 @@ mod tests { use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; - use crate::{BodyExt, Full}; + use crate::{BodyExt, Empty, Full}; #[allow(unused_imports)] use super::*; @@ -149,6 +176,46 @@ mod tests { assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none()); } + #[tokio::test] + async fn merges_trailers() { + let mut trailers_1 = HeaderMap::new(); + trailers_1.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + + let mut trailers_2 = HeaderMap::new(); + trailers_2.insert( + HeaderName::from_static("baz"), + HeaderValue::from_static("qux"), + ); + + let body = Empty::::new() + .with_trailers(std::future::ready(Some(Ok::<_, Infallible>( + trailers_1.clone(), + )))) + .with_trailers(std::future::ready(Some(Ok::<_, Infallible>( + trailers_2.clone(), + )))); + + futures_util::pin_mut!(body); + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_trailers() + .unwrap(); + + let mut all_trailers = HeaderMap::new(); + all_trailers.extend(trailers_1); + all_trailers.extend(trailers_2); + assert_eq!(body_trailers, all_trailers); + + assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none()); + } + fn unwrap_ready(poll: Poll) -> T { match poll { Poll::Ready(t) => t,