diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs index 8f68f916c..7efc12112 100644 --- a/ipa-core/src/helpers/transport/stream/buffered.rs +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -99,15 +99,18 @@ mod tests { num::NonZeroUsize, pin::Pin, sync::{Arc, Mutex}, + task, task::Poll, }; use bytes::Bytes; use futures::{stream::TryStreamExt, FutureExt, Stream, StreamExt}; + use pin_project::pin_project; use proptest::{ prop_compose, proptest, strategy::{Just, Strategy}, }; + use task::Context; use crate::{ error::BoxError, helpers::transport::stream::buffered::BufferedBytesStream, @@ -214,10 +217,10 @@ mod tests { chunk: usize, } + #[pin_project] struct FallibleTestStream { - total_size: usize, - remaining: usize, - chunk: usize, + #[pin] + inner: TestStream, error_after: usize, } @@ -231,9 +234,11 @@ mod tests { fn fallible_stream(total_size: usize, chunk: usize, error_after: usize) -> FallibleTestStream { FallibleTestStream { - total_size, - remaining: total_size, - chunk, + inner: TestStream { + total_size, + remaining: total_size, + chunk, + }, error_after, } } @@ -241,10 +246,7 @@ mod tests { impl Stream for TestStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.remaining == 0 { return Poll::Ready(None); } @@ -261,23 +263,19 @@ mod tests { impl Stream for FallibleTestStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.remaining == 0 { - return Poll::Ready(None); - } - let next_chunk_size = min(self.remaining, self.chunk); - let next_chunk = (0..next_chunk_size) - .map(|v| u8::try_from(v % 256).unwrap()) - .collect::>(); - - self.remaining -= next_chunk_size; - if self.total_size - self.remaining >= self.error_after { - Poll::Ready(Some(Err("error".into()))) - } else { - Poll::Ready(Some(Ok(Bytes::from(next_chunk)))) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => { + if this.inner.total_size - this.inner.remaining >= *this.error_after { + Poll::Ready(Some(Err("error".into()))) + } else { + Poll::Ready(Some(Ok(bytes))) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } }