Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add keepalive timeout when connection idle #827

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ atomic-waker = "1.0.0"
futures-core = { version = "0.3", default-features = false }
futures-sink = { version = "0.3", default-features = false }
tokio-util = { version = "0.7.1", features = ["codec", "io"] }
tokio = { version = "1", features = ["io-util"] }
tokio = { version = "1", features = ["io-util", "time"] }
bytes = "1"
http = "1"
tracing = { version = "0.1.35", default-features = false, features = ["std"] }
Expand Down
10 changes: 9 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ pub struct Builder {
///
/// When this gets exceeded, we issue GOAWAYs.
local_max_error_reset_streams: Option<usize>,

keepalive_timeout: Option<Duration>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -580,7 +582,6 @@ where
}
}

#[cfg(feature = "unstable")]
impl<B> SendRequest<B>
where
B: Buf,
Expand Down Expand Up @@ -661,6 +662,7 @@ impl Builder {
initial_target_connection_window_size: None,
initial_max_send_streams: usize::MAX,
settings: Default::default(),
keepalive_timeout: None,
stream_id: 1.into(),
local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX),
}
Expand Down Expand Up @@ -996,6 +998,11 @@ impl Builder {
self
}

/// Sets the duration connection should be closed when there no stream.
pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self {
self.keepalive_timeout = Some(dur);
self
}
/// Sets the maximum number of local resets due to protocol errors made by the remote end.
///
/// Invalid frames and many other protocol errors will lead to resets being generated for those streams.
Expand Down Expand Up @@ -1332,6 +1339,7 @@ where
max_send_buffer_size: builder.max_send_buffer_size,
reset_stream_duration: builder.reset_stream_duration,
reset_stream_max: builder.reset_stream_max,
keepalive_timeout: builder.keepalive_timeout,
remote_reset_stream_max: builder.pending_accept_reset_stream_max,
local_error_reset_streams_max: builder.local_max_error_reset_streams,
settings: builder.settings.clone(),
Expand Down
3 changes: 3 additions & 0 deletions src/frame/reason.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ impl Reason {
pub const INADEQUATE_SECURITY: Reason = Reason(12);
/// The endpoint requires that HTTP/1.1 be used instead of HTTP/2.
pub const HTTP_1_1_REQUIRED: Reason = Reason(13);
/// The endpoint reach keepalive timeout
pub const KEEPALIVE_TIMEOUT: Reason = Reason(14);

/// Get a string description of the error code.
pub fn description(&self) -> &str {
Expand All @@ -79,6 +81,7 @@ impl Reason {
11 => "detected excessive load generating behavior",
12 => "security properties do not meet minimum requirements",
13 => "endpoint requires HTTP/1.1",
14 => "keepalive timeout reached",
_ => "unknown reason",
}
}
Expand Down
44 changes: 39 additions & 5 deletions src/proto/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use crate::proto::*;

use bytes::Bytes;
use futures_core::Stream;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::AsyncRead;
use tokio::time::Sleep;

/// An H2 connection
#[derive(Debug)]
Expand Down Expand Up @@ -57,6 +59,9 @@ where
/// A `tracing` span tracking the lifetime of the connection.
span: tracing::Span,

keepalive: Option<Pin<Box<Sleep>>>,
keepalive_timeout: Option<Duration>,

/// Client or server
_phantom: PhantomData<P>,
}
Expand All @@ -82,6 +87,7 @@ pub(crate) struct Config {
pub reset_stream_max: usize,
pub remote_reset_stream_max: usize,
pub local_error_reset_streams_max: Option<usize>,
pub keepalive_timeout: Option<Duration>,
pub settings: frame::Settings,
}

Expand Down Expand Up @@ -135,6 +141,8 @@ where
ping_pong: PingPong::new(),
settings: Settings::new(config.settings),
streams,
keepalive: None,
keepalive_timeout: config.keepalive_timeout,
span: tracing::debug_span!("Connection", peer = %P::NAME),
_phantom: PhantomData,
},
Expand Down Expand Up @@ -173,6 +181,10 @@ where
pub(crate) fn max_recv_streams(&self) -> usize {
self.inner.streams.max_recv_streams()
}
/// Returns the number of active stream
pub(crate) fn active_streams(&self) -> usize {
self.inner.streams.num_active_streams()
}

#[cfg(feature = "unstable")]
pub fn num_wired_streams(&self) -> usize {
Expand Down Expand Up @@ -263,30 +275,52 @@ where
let _e = span.enter();
let span = tracing::trace_span!("poll");
let _e = span.enter();

loop {
'outer: loop {
tracing::trace!(connection.state = ?self.inner.state);
// TODO: probably clean up this glob of code
match self.inner.state {
// When open, continue to poll a frame
State::Open => {
let result = match self.poll2(cx) {
Poll::Ready(result) => result,
Poll::Ready(result) => {
self.inner.keepalive = None;
result
}
// The connection is not ready to make progress
Poll::Pending => {
// Ensure all window updates have been sent.
//
// This will also handle flushing `self.codec`
ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;

if (self.inner.error.is_some()
|| self.inner.go_away.should_close_on_idle())
&& !self.inner.streams.has_streams()
{
self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
continue;
}

if !self.inner.streams.has_streams() {
loop {
match (
self.inner.keepalive.as_mut(),
self.inner.keepalive_timeout,
) {
(Some(sleep), _) => {
ready!(sleep.as_mut().poll(cx));
self.inner
.as_dyn()
.go_away_now(Reason::KEEPALIVE_TIMEOUT);
continue 'outer;
}
(None, Some(timeout)) => {
self.inner
.keepalive
.replace(Box::pin(tokio::time::sleep(timeout)));
}
_ => break,
}
}
}
return Poll::Pending;
}
};
Expand Down
1 change: 0 additions & 1 deletion src/proto/streams/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,6 @@ where
self.inner.lock().unwrap().counts.max_recv_streams()
}

#[cfg(feature = "unstable")]
pub fn num_active_streams(&self) -> usize {
let me = self.inner.lock().unwrap();
me.store.num_active_streams()
Expand Down
20 changes: 19 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ pub struct Builder {
///
/// When this gets exceeded, we issue GOAWAYs.
local_max_error_reset_streams: Option<usize>,

/// Keepalive timeout
keepalive_timeout: Option<Duration>,
}

/// Send a response back to the client
Expand Down Expand Up @@ -581,6 +584,15 @@ where
self.connection.max_recv_streams()
}

/// Returns whether has stream alive
pub fn has_streams_or_other_references(&self) -> bool {
self.connection.has_streams_or_other_references()
}
/// Returns the number of current active stream.
pub fn active_stream(&self) -> usize {
self.connection.active_streams()
}

// Could disappear at anytime.
#[doc(hidden)]
#[cfg(feature = "unstable")]
Expand Down Expand Up @@ -650,7 +662,7 @@ impl Builder {
settings: Settings::default(),
initial_target_connection_window_size: None,
max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE,

keepalive_timeout: None,
local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX),
}
}
Expand Down Expand Up @@ -1015,6 +1027,11 @@ impl Builder {
self
}

/// Sets the duration connection should be closed when there no stream.
pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self {
self.keepalive_timeout = Some(dur);
self
}
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
Expand Down Expand Up @@ -1379,6 +1396,7 @@ where
initial_max_send_streams: 0,
max_send_buffer_size: self.builder.max_send_buffer_size,
reset_stream_duration: self.builder.reset_stream_duration,
keepalive_timeout: self.builder.keepalive_timeout,
reset_stream_max: self.builder.reset_stream_max,
remote_reset_stream_max: self.builder.pending_accept_reset_stream_max,
local_error_reset_streams_max: self
Expand Down
35 changes: 35 additions & 0 deletions tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,41 @@ async fn server_builder_set_max_concurrent_streams() {
join(client, h2).await;
}

#[tokio::test]
async fn server_builder_set_keepalive_timeout() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let h1 = async {
let settings = client.assert_server_handshake().await;
assert_default_settings!(settings);
client
.send_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
client
.recv_frame(frames::headers(1).response(200).eos())
.await;
};

let mut builder = server::Builder::new();
builder.keepalive_timeout(Duration::from_secs(2));
let h2 = async move {
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
let (req, mut stream) = srv.next().await.unwrap().unwrap();
assert_eq!(req.method(), &http::Method::GET);

let rsp = http::Response::builder().status(200).body(()).unwrap();
let res = stream.send_response(rsp, true).unwrap();
drop(res);
let r1 = srv.accept().await;
println!("rrr {r1:?}");
assert!(r1.is_some_and(|f| f.is_err_and(|f| f.is_go_away())));
};
join(h1, h2).await;
}
#[tokio::test]
async fn serve_request() {
h2_support::trace_init!();
Expand Down