Skip to content

Commit

Permalink
feat: Expose GrpcTimeout middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Jan 31, 2025
1 parent 78be69e commit a97edeb
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 13 deletions.
6 changes: 4 additions & 2 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,22 @@ tls-aws-lc = ["_tls-any", "tokio-rustls/aws-lc-rs"]
tls-native-roots = ["_tls-any", "channel", "dep:rustls-native-certs"]
tls-webpki-roots = ["_tls-any","channel", "dep:webpki-roots"]
router = ["dep:axum", "dep:tower", "tower?/util"]
timeout = ["dep:tokio", "tokio?/time"]
server = [
"timeout",
"dep:h2",
"dep:hyper", "hyper?/server",
"dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto",
"dep:socket2",
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
"dep:tokio", "tokio?/macros", "tokio?/net",
"tokio-stream/net",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
"timeout",
"dep:hyper", "hyper?/client",
"dep:hyper-util", "hyper-util?/client-legacy",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
"dep:tokio", "tokio?/time",
"dep:hyper-timeout",
]
transport = ["server", "channel"]
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
//! - `server`: Enables just the full featured server portion of the `transport` feature.
//! - `channel`: Enables just the full featured channel portion of the `transport` feature.
//! - `router`: Enables the [`axum`] based service router. Enabled by default.
//! - `timeout`: Enables timeout related feature including `GrpcTimeout` middleware. Enabled
//! by default.
//! - `codegen`: Enables all the required exports and optional dependencies required
//! for [`tonic-build`]. Enabled by default.
//! - `tls-ring`: Enables the [`rustls`] based TLS options for the `transport` feature using
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
//! Middleware which implements gRPC timeout.
use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
use http::{HeaderMap, HeaderValue, Request};
use pin_project::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
time::Duration,
};
use tokio::time::Sleep;
use tower_layer::Layer;
use tower_service::Service;

/// Layer which applies the [`GrpcTimeout`] middleware.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
pub struct GrpcTimeoutLayer {
server_timeout: Option<Duration>,
}

impl<S> Layer<S> for GrpcTimeoutLayer {
type Service = GrpcTimeout<S>;

fn layer(&self, inner: S) -> Self::Service {
GrpcTimeout::new(inner, self.server_timeout)
}
}

impl GrpcTimeoutLayer {
/// Create a new `GrpcTimeoutLayer`.
pub fn new(server_timeout: Option<Duration>) -> Self {
Self { server_timeout }
}
}

/// Middleware which implements gRPC timeout.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
inner: S,
server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
/// Create a new [`GrpcTimeout`] middleware.
pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
Self {
inner,
server_timeout,
Expand Down Expand Up @@ -62,14 +89,21 @@ where
}
}

/// Response future for [`GrpcTimeout`].
#[pin_project]
pub(crate) struct ResponseFuture<F> {
pub struct ResponseFuture<F> {
#[pin]
inner: F,
#[pin]
sleep: Option<Sleep>,
}

impl<F> fmt::Debug for ResponseFuture<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}

impl<F, Res, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Res, E>>,
Expand Down
5 changes: 5 additions & 0 deletions tonic/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ pub use axum::{body::Body as AxumBody, Router as AxumRouter};

pub mod recover_error;
pub use self::recover_error::{RecoverError, RecoverErrorLayer};

#[cfg(feature = "timeout")]
pub mod grpc_timeout;
#[cfg(feature = "timeout")]
pub use self::grpc_timeout::{GrpcTimeout, GrpcTimeoutLayer};
5 changes: 3 additions & 2 deletions tonic/src/transport/channel/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
use crate::{
body::Body,
transport::{channel::BoxFuture, service::GrpcTimeout, Endpoint},
service::GrpcTimeoutLayer,
transport::{channel::BoxFuture, Endpoint},
};
use http::{Request, Response, Uri};
use hyper::rt;
Expand Down Expand Up @@ -62,7 +63,7 @@ impl Connection {
AddOrigin::new(s, origin)
})
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.layer(GrpcTimeoutLayer::new(endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
.into_inner();
Expand Down
5 changes: 2 additions & 3 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ pub use incoming::TcpIncoming;
use crate::transport::Error;

use self::service::{ConnectInfoLayer, ServerIo};
use super::service::GrpcTimeout;
use crate::body::Body;
use crate::service::RecoverErrorLayer;
use crate::service::{GrpcTimeoutLayer, RecoverErrorLayer};
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
Expand Down Expand Up @@ -1090,7 +1089,7 @@ where
let svc = ServiceBuilder::new()
.layer(RecoverErrorLayer::new())
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.layer(GrpcTimeoutLayer::new(timeout))
.service(svc);

let svc = ServiceBuilder::new()
Expand Down
3 changes: 0 additions & 3 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
pub(crate) mod grpc_timeout;
#[cfg(feature = "_tls-any")]
pub(crate) mod tls;

pub(crate) use self::grpc_timeout::GrpcTimeout;

0 comments on commit a97edeb

Please sign in to comment.