From e86bbaf16581395db54d5c4e271bd655972ad623 Mon Sep 17 00:00:00 2001 From: Andrew Werner Date: Tue, 6 Aug 2024 11:07:15 -0400 Subject: [PATCH] h2-support,h2-tests: add tools to ensure wake This commit adds wrappers around futures::future helpers and augments TestFuture to ensure that the underlying futures are notified before they are polled. This helps to catch bugs where there are missing notify calls or bad handling of the waker. The commit then extends the tests to use these helpers instead of the library functions from futures. It also ammends the client_requests::recv_too_big_headers test to no longer use the tokio spawned tasks that were added in #791. --- tests/h2-support/src/future_ext.rs | 141 ++++++++++++++++++++++++- tests/h2-support/src/prelude.rs | 2 +- tests/h2-tests/tests/client_request.rs | 22 ++-- tests/h2-tests/tests/codec_read.rs | 1 - tests/h2-tests/tests/codec_write.rs | 1 - tests/h2-tests/tests/flow_control.rs | 1 - tests/h2-tests/tests/ping_pong.rs | 1 - tests/h2-tests/tests/prioritization.rs | 1 - tests/h2-tests/tests/push_promise.rs | 16 +-- tests/h2-tests/tests/server.rs | 1 - tests/h2-tests/tests/stream_states.rs | 2 +- 11 files changed, 151 insertions(+), 38 deletions(-) diff --git a/tests/h2-support/src/future_ext.rs b/tests/h2-support/src/future_ext.rs index 9f659b344..cca18c66e 100644 --- a/tests/h2-support/src/future_ext.rs +++ b/tests/h2-support/src/future_ext.rs @@ -1,7 +1,9 @@ -use futures::FutureExt; +use futures::{FutureExt, TryFuture}; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::task::{Context, Poll, Wake, Waker}; /// Future extension helpers that are useful for tests pub trait TestFuture: Future { @@ -15,9 +17,140 @@ pub trait TestFuture: Future { { Drive { driver: self, - future: Box::pin(other), + future: other.wakened(), } } + + fn wakened(self) -> Wakened + where + Self: Sized, + { + Wakened { + future: Box::pin(self), + woken: Arc::new(AtomicBool::new(true)), + } + } +} + +/// Wraps futures::future::join to ensure that the futures are only polled if they are woken. +pub fn join( + future1: Fut1, + future2: Fut2, +) -> futures::future::Join, Wakened> +where + Fut1: Future, + Fut2: Future, +{ + futures::future::join(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken. +pub fn join3( + future1: Fut1, + future2: Fut2, + future3: Fut3, +) -> futures::future::Join3, Wakened, Wakened> +where + Fut1: Future, + Fut2: Future, + Fut3: Future, +{ + futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened()) +} + +/// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken. +pub fn join4( + future1: Fut1, + future2: Fut2, + future3: Fut3, + future4: Fut4, +) -> futures::future::Join4, Wakened, Wakened, Wakened> +where + Fut1: Future, + Fut2: Future, + Fut3: Future, + Fut4: Future, +{ + futures::future::join4( + future1.wakened(), + future2.wakened(), + future3.wakened(), + future4.wakened(), + ) +} + +/// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken. +pub fn try_join( + future1: Fut1, + future2: Fut2, +) -> futures::future::TryJoin, Wakened> +where + Fut1: futures::future::TryFuture + Future, + Fut2: Future, + Wakened: futures::future::TryFuture, + Wakened: futures::future::TryFuture as TryFuture>::Error>, +{ + futures::future::try_join(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::select to ensure that the futures are only polled if they are woken. +pub fn select(future1: A, future2: B) -> futures::future::Select, Wakened> +where + A: Future + Unpin, + B: Future + Unpin, +{ + futures::future::select(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken. +pub fn join_all(iter: I) -> futures::future::JoinAll> +where + I: IntoIterator, + I::Item: Future, +{ + futures::future::join_all(iter.into_iter().map(|f| f.wakened())) +} + +/// A future that only polls the inner future if it has been woken (after the initial poll). +pub struct Wakened { + future: Pin>, + woken: Arc, +} + +/// A future that only polls the inner future if it has been woken (after the initial poll). +impl Future for Wakened +where + T: Future, +{ + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + if !this.woken.load(std::sync::atomic::Ordering::SeqCst) { + return Poll::Pending; + } + this.woken.store(false, std::sync::atomic::Ordering::SeqCst); + let my_waker = IfWokenWaker { + inner: cx.waker().clone(), + wakened: this.woken.clone(), + }; + let my_waker = Arc::new(my_waker).into(); + let mut cx = Context::from_waker(&my_waker); + this.future.as_mut().poll(&mut cx) + } +} + +impl Wake for IfWokenWaker { + fn wake(self: Arc) { + self.wakened + .store(true, std::sync::atomic::Ordering::SeqCst); + self.inner.wake_by_ref(); + } +} + +struct IfWokenWaker { + inner: Waker, + wakened: Arc, } impl TestFuture for T {} @@ -29,7 +162,7 @@ impl TestFuture for T {} /// This is useful for H2 futures that also require the connection to be polled. pub struct Drive<'a, T, U> { driver: &'a mut T, - future: Pin>, + future: Wakened, } impl<'a, T, U> Future for Drive<'a, T, U> diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index c40a518da..3c6e15d75 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -35,7 +35,7 @@ pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscribe pub use futures::{Future, Sink, Stream}; // And our Future extensions -pub use super::future_ext::TestFuture; +pub use super::future_ext::{TestFuture, join, join_all, select, join3, join4, try_join}; // Our client_ext helpers pub use super::client_ext::SendRequestExt; diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 50be06e66..8e0e599a7 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1,4 +1,4 @@ -use futures::future::{join, join_all, ready, select, Either}; +use futures::future::{ready, Either}; use futures::stream::FuturesUnordered; use futures::StreamExt; use h2_support::prelude::*; @@ -849,7 +849,7 @@ async fn recv_too_big_headers() { }; let client = async move { - let (mut client, conn) = client::Builder::new() + let (mut client, mut conn) = client::Builder::new() .max_header_list_size(10) .handshake::<_, Bytes>(io) .await @@ -863,10 +863,10 @@ async fn recv_too_big_headers() { let req1 = client.send_request(request, true); // Spawn tasks to ensure that the error wakes up tasks that are blocked // waiting for a response. - let req1 = tokio::spawn(async move { + let req1 = async move { let err = req1.expect("send_request").0.await.expect_err("response1"); assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); - }); + }; let request = Request::builder() .uri("https://http2.akamai.com/") @@ -874,19 +874,13 @@ async fn recv_too_big_headers() { .unwrap(); let req2 = client.send_request(request, true); - let req2 = tokio::spawn(async move { + let req2 = async move { let err = req2.expect("send_request").0.await.expect_err("response2"); assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); - }); + }; - let conn = tokio::spawn(async move { - conn.await.expect("client"); - }); - for err in join_all([req1, req2, conn]).await { - if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) { - std::panic::resume_unwind(err); - } - } + + conn.drive(join(req1, req2)).await; }; join(srv, client).await; diff --git a/tests/h2-tests/tests/codec_read.rs b/tests/h2-tests/tests/codec_read.rs index d955e186b..489d16daf 100644 --- a/tests/h2-tests/tests/codec_read.rs +++ b/tests/h2-tests/tests/codec_read.rs @@ -1,4 +1,3 @@ -use futures::future::join; use h2_support::prelude::*; #[tokio::test] diff --git a/tests/h2-tests/tests/codec_write.rs b/tests/h2-tests/tests/codec_write.rs index 0b85a2238..04627cdc9 100644 --- a/tests/h2-tests/tests/codec_write.rs +++ b/tests/h2-tests/tests/codec_write.rs @@ -1,4 +1,3 @@ -use futures::future::join; use h2_support::prelude::*; #[tokio::test] diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index dbb933286..e3caaff5f 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1,4 +1,3 @@ -use futures::future::{join, join4}; use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once; diff --git a/tests/h2-tests/tests/ping_pong.rs b/tests/h2-tests/tests/ping_pong.rs index 0f93578cc..2132c7acf 100644 --- a/tests/h2-tests/tests/ping_pong.rs +++ b/tests/h2-tests/tests/ping_pong.rs @@ -1,5 +1,4 @@ use futures::channel::oneshot; -use futures::future::join; use futures::StreamExt; use h2_support::assert_ping; use h2_support::prelude::*; diff --git a/tests/h2-tests/tests/prioritization.rs b/tests/h2-tests/tests/prioritization.rs index 11d2c2ccf..dd4ed9fea 100644 --- a/tests/h2-tests/tests/prioritization.rs +++ b/tests/h2-tests/tests/prioritization.rs @@ -1,4 +1,3 @@ -use futures::future::{join, select}; use futures::{pin_mut, FutureExt, StreamExt}; use h2_support::prelude::*; diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index c2138edcd..fbc375586 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -1,6 +1,4 @@ -use std::iter::FromIterator; - -use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; #[tokio::test] @@ -52,15 +50,9 @@ async fn recv_push_works() { let ps: Vec<_> = p.collect().await; assert_eq!(1, ps.len()) }; - // Use a FuturesUnordered to poll both tasks but only poll them - // if they have been notified. - let tasks = futures::stream::FuturesUnordered::from_iter([ - check_resp_status.boxed(), - check_pushed_response.boxed(), - ]) - .collect::<()>(); - - h2.drive(tasks).await; + + h2.drive(join(check_resp_status, check_pushed_response)) + .await; }; join(mock, h2).await; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 7155b5868..c266bc2d5 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1,6 +1,5 @@ #![deny(warnings)] -use futures::future::join; use futures::StreamExt; use h2_support::prelude::*; use tokio::io::AsyncWriteExt; diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 05a96a0f5..9a377d798 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1,6 +1,6 @@ #![deny(warnings)] -use futures::future::{join, join3, lazy, try_join}; +use futures::future::lazy; use futures::{FutureExt, StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once;