Skip to content

Commit

Permalink
add future-ext
Browse files Browse the repository at this point in the history
I find the syntax for `tokio::time::timeout` to be quite ugly, since it
breaks the typical chained method calls. With the FutureExt trait we can
just call `.with_timeout(..)` instead which seems to be a cleaner way to
do it.
  • Loading branch information
TroyKomodo committed Jan 5, 2025
1 parent e4800b4 commit 8b5bc9b
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 27 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ members = [
"crates/rtmp",
"crates/transmuxer",
"dev-tools/xtask",
"crates/future-ext",
]

resolver = "2"
Expand Down
1 change: 1 addition & 0 deletions crates/bytesio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ scuffle-workspace-hack.workspace = true

[dev-dependencies]
tokio = { version = "1.36", features = ["full"] }
scuffle-future-ext = { path = "../future-ext" }
7 changes: 5 additions & 2 deletions crates/bytesio/src/tests/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use scuffle_future_ext::FutureExt;

use crate::bytesio_errors::BytesIOError;

#[tokio::test]
async fn test_timeout_error_display() {
let err = tokio::time::timeout(std::time::Duration::from_millis(100), async {
let err = async {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
})
}
.with_timeout(std::time::Duration::from_millis(100))
.await
.unwrap_err();
let bytes_io_error = BytesIOError::from(err);
Expand Down
9 changes: 9 additions & 0 deletions crates/future-ext/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "scuffle-future-ext"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"

[dependencies]
tokio = { version = "1", features = ["time"] }
scuffle-workspace-hack.workspace = true
26 changes: 26 additions & 0 deletions crates/future-ext/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/// The [`FutureExt`] trait is a trait that provides a more ergonomic way to
/// extend futures with additional functionality. Similar to the [`IteratorExt`]
/// trait, but for futures.
pub trait FutureExt {
/// Attach a timeout to the future.
///
/// The timeout is relative to the current time.
fn with_timeout(self, duration: tokio::time::Duration) -> tokio::time::Timeout<Self>
where
Self: Sized;

/// Similar to `with_timeout`, but the timeout is absolute.
fn with_timeout_at(self, deadline: tokio::time::Instant) -> tokio::time::Timeout<Self>
where
Self: Sized;
}

impl<F: std::future::Future> FutureExt for F {
fn with_timeout(self, duration: tokio::time::Duration) -> tokio::time::Timeout<Self> {
tokio::time::timeout(duration, self)
}

fn with_timeout_at(self, deadline: tokio::time::Instant) -> tokio::time::Timeout<Self> {
tokio::time::timeout_at(deadline, self)
}
}
2 changes: 1 addition & 1 deletion crates/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ itoa = { version = "1" }
smallvec = { version = "1" }
spin = { version = "0.9" }
async-trait = { version = "0.1" }

scuffle-future-ext = { path = "../future-ext" }
# For extra services features
tower-service = { version = "0.3", optional = true }
axum-core = { version = "0.4", optional = true }
Expand Down
5 changes: 3 additions & 2 deletions crates/http/src/backend/quic/quinn/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bytes::Bytes;
use h3::server::RequestStream;
use http::Response;
use scuffle_context::ContextFutExt;
use scuffle_future_ext::FutureExt;
#[cfg(feature = "http3-webtransport")]
use scuffle_h3_webtransport::server::WebTransportUpgradePending;

Expand Down Expand Up @@ -122,7 +123,7 @@ async fn serve_handle_inner(
.with_context(&ctx);

let Some(connection) = if let Some(timeout) = config.handshake_timeout {
tokio::time::timeout(timeout, conn).await.with_config(ErrorConfig {
conn.with_timeout(timeout).await.with_config(ErrorConfig {
context: "quinn handshake",
scope: ErrorScope::Connection,
severity: ErrorSeverity::Debug,
Expand All @@ -147,7 +148,7 @@ async fn serve_handle_inner(
.with_context(&ctx);

if let Some(timeout) = config.handshake_timeout {
tokio::time::timeout(timeout, fut).await.with_config(ErrorConfig {
fut.with_timeout(timeout).await.with_config(ErrorConfig {
context: "quinn handshake",
scope: ErrorScope::Connection,
severity: ErrorSeverity::Debug,
Expand Down
11 changes: 7 additions & 4 deletions crates/http/src/backend/tcp/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ async fn serve_stream_inner(
match tls_acceptor {
#[cfg(feature = "tls-rustls")]
Some(acceptor) => {
use crate::error::{ErrorConfig, ErrorKind, ErrorScope, ErrorSeverity, ResultErrorExt};
use scuffle_future_ext::FutureExt;

use crate::error::{ErrorConfig, ErrorKind, ErrorScope, ErrorSeverity, ResultErrorExt};
let Some(stream) = async {
// We should read a bit of the stream to see if they are attempting to use TLS
// or not. This is so we can immediately return a bad request if they arent
Expand All @@ -115,7 +116,7 @@ async fn serve_stream_inner(
let is_tls = util::is_tls(&mut stream, handle);

let is_tls = if let Some(timeout) = config.handshake_timeout {
tokio::time::timeout(timeout, is_tls).await.with_config(ErrorConfig {
is_tls.with_timeout(timeout).await.with_config(ErrorConfig {
context: "tls handshake",
scope: ErrorScope::Connection,
severity: ErrorSeverity::Debug,
Expand All @@ -135,7 +136,7 @@ async fn serve_stream_inner(
let lazy = tokio_rustls::LazyConfigAcceptor::new(Default::default(), stream);

let accepted = if let Some(timeout) = config.handshake_timeout {
tokio::time::timeout(timeout, lazy).await.with_config(ErrorConfig {
lazy.with_timeout(timeout).await.with_config(ErrorConfig {
context: "tls handshake",
scope: ErrorScope::Connection,
severity: ErrorSeverity::Debug,
Expand All @@ -154,7 +155,9 @@ async fn serve_stream_inner(
};

let stream = if let Some(timeout) = config.handshake_timeout {
tokio::time::timeout(timeout, accepted.into_stream(tls_config))
accepted
.into_stream(tls_config)
.with_timeout(timeout)
.await
.with_config(ErrorConfig {
context: "tls handshake",
Expand Down
1 change: 1 addition & 0 deletions crates/rtmp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ tracing = "0.1"
bytesio = { path = "../bytesio", features = ["default"] }
amf0 = { path = "../amf0" }
scuffle-workspace-hack.workspace = true
scuffle-future-ext = { path = "../future-ext" }

[dev-dependencies]
tokio = { version = "1.36", features = ["full"] }
Expand Down
18 changes: 13 additions & 5 deletions crates/rtmp/src/session/server_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use bytes::Bytes;
use bytesio::bytes_writer::BytesWriter;
use bytesio::bytesio::{AsyncReadWrite, BytesIO};
use bytesio::bytesio_errors::BytesIOError;
use scuffle_future_ext::FutureExt;
use tokio::sync::oneshot;

use super::define::RtmpCommand;
Expand All @@ -19,7 +20,6 @@ use crate::netstream::NetStreamWriter;
use crate::protocol_control_messages::ProtocolControlMessagesWriter;
use crate::user_control_messages::EventMessagesWriter;
use crate::{handshake, PublishProducer};

pub struct Session<S: AsyncReadWrite> {
/// When you connect via rtmp, you specify the app name in the url
/// For example: rtmp://localhost:1935/live/xyz
Expand Down Expand Up @@ -131,7 +131,10 @@ impl<S: AsyncReadWrite> Session<S> {
let mut bytes_len = 0;

while bytes_len < handshake::RTMP_HANDSHAKE_SIZE {
let buf = tokio::time::timeout(Duration::from_millis(2500), self.io.read())
let buf = self
.io
.read()
.with_timeout(Duration::from_millis(2500))
.await
.map_err(|_| SessionError::BytesIO(BytesIOError::ClientClosed))??;
bytes_len += buf.len();
Expand Down Expand Up @@ -172,7 +175,10 @@ impl<S: AsyncReadWrite> Session<S> {
if self.skip_read {
self.skip_read = false;
} else {
let data = tokio::time::timeout(Duration::from_millis(2500), self.io.read())
let data = self
.io
.read()
.with_timeout(Duration::from_millis(2500))
.await
.map_err(|_| SessionError::BytesIO(BytesIOError::ClientClosed))??;
self.chunk_decoder.extend_data(&data[..]);
Expand Down Expand Up @@ -250,7 +256,7 @@ impl<S: AsyncReadWrite> Session<S> {
};

if matches!(
tokio::time::timeout(Duration::from_secs(2), self.data_producer.send(data)).await,
self.data_producer.send(data).with_timeout(Duration::from_secs(2)).await,
Err(_) | Ok(Err(_))
) {
tracing::debug!("Publisher dropped");
Expand Down Expand Up @@ -504,7 +510,9 @@ impl<S: AsyncReadWrite> Session<S> {
/// This is to avoid writing empty bytes to the underlying connection.
async fn write_data(&mut self, data: Bytes) -> Result<(), SessionError> {
if !data.is_empty() {
tokio::time::timeout(Duration::from_secs(2), self.io.write(data))
self.io
.write(data)
.with_timeout(Duration::from_secs(2))
.await
.map_err(|_| SessionError::BytesIO(BytesIOError::ClientClosed))??;
}
Expand Down
25 changes: 19 additions & 6 deletions crates/rtmp/src/tests/rtmp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::path::PathBuf;
use std::time::Duration;

use scuffle_future_ext::FutureExt;
use tokio::process::Command;
use tokio::sync::mpsc;

Expand Down Expand Up @@ -34,7 +35,9 @@ async fn test_basic_rtmp_clean() {
.spawn()
.expect("failed to execute ffmpeg");

let (ffmpeg_stream, _) = tokio::time::timeout(Duration::from_millis(1000), listener.accept())
let (ffmpeg_stream, _) = listener
.accept()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
.expect("failed to accept");
Expand All @@ -55,7 +58,9 @@ async fn test_basic_rtmp_clean() {
)
};

let event = tokio::time::timeout(Duration::from_millis(1000), ffmpeg_event_reciever.recv())
let event = ffmpeg_event_reciever
.recv()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
.expect("failed to recv event");
Expand All @@ -70,7 +75,9 @@ async fn test_basic_rtmp_clean() {
let mut got_audio = false;
let mut got_metadata = false;

while let Some(data) = tokio::time::timeout(Duration::from_millis(1000), ffmpeg_data_reciever.recv())
while let Some(data) = ffmpeg_data_reciever
.recv()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
{
Expand Down Expand Up @@ -119,7 +126,9 @@ async fn test_basic_rtmp_unclean() {
.spawn()
.expect("failed to execute ffmpeg");

let (ffmpeg_stream, _) = tokio::time::timeout(Duration::from_millis(1000), listener.accept())
let (ffmpeg_stream, _) = listener
.accept()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
.expect("failed to accept");
Expand All @@ -140,7 +149,9 @@ async fn test_basic_rtmp_unclean() {
)
};

let event = tokio::time::timeout(Duration::from_millis(1000), ffmpeg_event_reciever.recv())
let event = ffmpeg_event_reciever
.recv()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
.expect("failed to recv event");
Expand All @@ -155,7 +166,9 @@ async fn test_basic_rtmp_unclean() {
let mut got_audio = false;
let mut got_metadata = false;

while let Some(data) = tokio::time::timeout(Duration::from_millis(1000), ffmpeg_data_reciever.recv())
while let Some(data) = ffmpeg_data_reciever
.recv()
.with_timeout(Duration::from_millis(1000))
.await
.expect("timedout")
{
Expand Down
1 change: 1 addition & 0 deletions crates/signal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ scuffle-workspace-hack.workspace = true
tokio = { version = "1.41.1", features = ["macros", "rt", "time"] }
libc = "0.2"
futures = "0.3"
scuffle-future-ext = { path = "../future-ext" }

[features]
bootstrap = ["scuffle-bootstrap", "scuffle-context", "anyhow", "tokio/macros"]
14 changes: 7 additions & 7 deletions crates/signal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl std::future::Future for SignalHandler {

#[cfg(test)]
mod tests {
use std::time::Duration;

use scuffle_future_ext::FutureExt;

use super::*;

fn raise_signal(kind: SignalKind) {
Expand All @@ -127,23 +131,19 @@ mod tests {

raise_signal(SignalKind::user_defined1());

let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler)
.await
.unwrap();
let recv = (&mut handler).with_timeout(Duration::from_millis(5)).await.unwrap();

assert_eq!(recv, SignalKind::user_defined1(), "expected SIGUSR1");

// We already received the signal, so polling again should return Poll::Pending
let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler).await;
let recv = (&mut handler).with_timeout(Duration::from_millis(5)).await;

assert!(recv.is_err(), "expected timeout");

raise_signal(SignalKind::user_defined2());

// We should be able to receive the signal again
let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler)
.await
.unwrap();
let recv = (&mut handler).with_timeout(Duration::from_millis(5)).await.unwrap();

assert_eq!(recv, SignalKind::user_defined2(), "expected SIGUSR2");
}
Expand Down

0 comments on commit 8b5bc9b

Please sign in to comment.