diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 747635056..aa070f3ac 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -83,7 +83,7 @@ h2 = { version = "0.4", optional = true } hyper = { version = "1.0", features = ["full"], optional = true } hyper-util = { version = "0.1", features = ["full"] } hyper-timeout = { version = "0.5", optional = true } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["net"] } tower = { version = "0.4.7", default-features = false, features = [ "balance", "buffer", diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..ca9984c51 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -1,11 +1,13 @@ -use bencher::{benchmark_group, benchmark_main, Bencher}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http_body::Body; use std::{ fmt::{Error, Formatter}, pin::Pin, task::{Context, Poll}, }; + +use bencher::{benchmark_group, benchmark_main, Bencher}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http_body::{Body, Frame, SizeHint}; + use tonic::{codec::DecodeBuf, codec::Decoder, Status, Streaming}; macro_rules! bench { @@ -58,23 +60,24 @@ impl Body for MockBody { type Data = Bytes; type Error = Status; - fn poll_data( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { if self.data.has_remaining() { let split = std::cmp::min(self.chunk_size, self.data.remaining()); - Poll::Ready(Some(Ok(self.data.split_to(split)))) + Poll::Ready(Some(Ok(Frame::data(self.data.split_to(split))))) } else { Poll::Ready(None) } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + fn is_end_stream(&self) -> bool { + !self.data.is_empty() + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.data.len() as u64) } } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index cb88a0649..4ba5318de 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -4,6 +4,7 @@ use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http::StatusCode; use http_body::Body; +use http_body_util::BodyExt; use std::{ fmt, future, pin::Pin, @@ -122,7 +123,9 @@ impl Streaming { decoder: Box::new(decoder), inner: StreamingInner { body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_frame(|mut frame| { + frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + }) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, @@ -231,7 +234,7 @@ impl StreamingInner { // Returns Some(()) if data was found or None if the loop in `poll_next` should break fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index fdaa450ff..5586cc8de 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -337,4 +337,3 @@ where self.state.is_end_stream } } - diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 74e307352..5b85ba24d 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -328,4 +328,3 @@ mod tests { } } } - diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 896a9f873..f74ec8910 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -95,4 +95,3 @@ impl GrpcMethod { self.method } } - diff --git a/tonic/src/status.rs b/tonic/src/status.rs index d0ba18a34..2e691df0c 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -1003,4 +1003,3 @@ mod tests { assert_eq!(status.details(), DETAILS); } } - diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index ac23de98e..4636b0fee 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -7,9 +7,9 @@ use crate::transport::service::TlsConnector; use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; +use hyper::rt; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; -use tower::make::MakeConnection; -// use crate::transport::E +use tower_service::Service; /// Channel builder. /// @@ -359,8 +359,8 @@ impl Endpoint { /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied. pub async fn connect_with_connector(&self, connector: C) -> Result where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -384,8 +384,8 @@ impl Endpoint { /// uses a Unix socket transport. pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b1bbc6046..420fd1f3c 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -21,12 +21,10 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{channel, Sender}, -}; +use tokio::sync::mpsc::{channel, Sender}; -use axum::{extract::Request, response::Response, body::Body}; +use axum::{body::Body, extract::Request, response::Response}; +use hyper::rt; use tower::balance::p2c::Balance; use tower::{ buffer::{self, Buffer}, @@ -149,7 +147,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -166,7 +164,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index a0435c797..8c892e5bc 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -107,8 +107,9 @@ pub use self::service::grpc_timeout::TimeoutExpired; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; -pub use hyper::{Body, Uri}; +pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter}; +pub use hyper::body::Body; +pub use hyper::Uri; pub(crate) use self::service::executor::Executor; diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 37bcc561b..11fcc4fd0 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,13 +1,14 @@ use std::net::SocketAddr; -use tokio::net::TcpStream; - -#[cfg(feature = "tls")] -use crate::transport::Certificate; #[cfg(feature = "tls")] use std::sync::Arc; + +use tokio::net::TcpStream; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; +#[cfg(feature = "tls")] +use crate::transport::Certificate; + /// Trait that connected IO resources implement and use to produce info about the connection. /// /// The goal for this trait is to allow users to implement diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index b78bae881..768bd29f8 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -13,7 +13,7 @@ pub use super::service::Routes; pub use super::service::RoutesBuilder; pub use conn::{Connected, TcpConnectInfo}; -use hyper_util::rt::TokioExecutor; +use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; @@ -36,12 +36,12 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; -use crate::body::BoxBody; use crate::server::NamedService; +use axum::extract::Request; +use axum::response::Response; use bytes::Bytes; -use http::{Request, Response}; -use http_body::Body as _; -use hyper::Body; +use http_body_util::BodyExt; +use hyper::body::Body; use pin_project::pin_project; use std::{ convert::Infallible, @@ -55,7 +55,6 @@ use std::{ time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; use tokio_stream::Stream; use tower::{ layer::util::{Identity, Stack}, @@ -65,9 +64,8 @@ use tower::{ Service, ServiceBuilder, }; -type BoxHttpBody = http_body_util::combinators::UnsyncBoxBody; -type BoxService = tower::util::BoxService, Response, crate::Error>; -type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; +type BoxService = tower::util::BoxService; +type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; @@ -361,7 +359,7 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service + NamedService + Clone + Send @@ -382,7 +380,7 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service + NamedService + Clone + Send @@ -496,15 +494,15 @@ impl Server { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, - ResBody: http_body::Body + Send + 'static, + ResBody: Body + Send + 'static, ResBody::Error: Into, { let trace_interceptor = self.trace_interceptor.clone(); @@ -525,9 +523,7 @@ impl Server { let svc = self.service_builder.service(svc); - let tcp = incoming::tcp_incoming(incoming, self); - let incoming = TcpStream::accept::from_stream::<_, _, crate::Error>(tcp); - + let incoming = incoming::tcp_incoming(incoming, self); let svc = MakeSvc { inner: svc, concurrency_limit, @@ -536,7 +532,9 @@ impl Server { _io: PhantomData, }; - let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + + builder .http2() .initial_connection_window_size(init_connection_window_size) .initial_stream_window_size(init_stream_window_size) @@ -548,15 +546,8 @@ impl Server { //.max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) .max_frame_size(max_frame_size); - if let Some(signal) = signal { - server - .serve(svc) - .with_graceful_shutdown(signal) - .await - .map_err(super::Error::from_source)? - } else { - server.serve(svc).await.map_err(super::Error::from_source)?; - } + let io = TokioIo::new(incoming); + let connection = builder.serve_connection(io, svc); Ok(()) } @@ -572,7 +563,7 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service + NamedService + Clone + Send @@ -591,7 +582,7 @@ impl Router { #[allow(clippy::type_complexity)] pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service + NamedService + Clone + Send @@ -617,10 +608,10 @@ impl Router { pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, + ResBody: Body + Send + 'static, ResBody::Error: Into, { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) @@ -647,10 +638,10 @@ impl Router { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, + ResBody: Body + Send + 'static, ResBody::Error: Into, { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) @@ -676,10 +667,10 @@ impl Router { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, + ResBody: Body + Send + 'static, ResBody::Error: Into, { self.server @@ -711,10 +702,10 @@ impl Router { IE: Into, F: Future, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, + ResBody: Body + Send + 'static, ResBody::Error: Into, { self.server @@ -726,10 +717,10 @@ impl Router { pub fn into_service(self) -> L::Service where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + <>::Service as Service>::Future: Send + 'static, + <>::Service as Service>::Error: Into + Send, + ResBody: Body + Send + 'static, ResBody::Error: Into, { self.server.service_builder.service(self.routes.prepare()) @@ -747,14 +738,14 @@ struct Svc { trace_interceptor: Option, } -impl Service> for Svc +impl Service for Svc where - S: Service, Response = Response>, + S: Service>, S::Error: Into, - ResBody: http_body::Body + Send + 'static, + ResBody: Body + Send + 'static, ResBody::Error: Into, { - type Response = Response; + type Response = Response; type Error = crate::Error; type Future = SvcFuture; @@ -762,7 +753,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -795,10 +786,10 @@ impl Future for SvcFuture where F: Future, E>>, E: Into, - ResBody: http_body::Body + Send + 'static, + ResBody: Body + Send + 'static, ResBody::Error: Into, { - type Output = Result, crate::Error>; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -827,10 +818,10 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where IO: Connected, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: Body + Send + 'static, ResBody::Error: Into, { type Response = BoxService; @@ -857,7 +848,7 @@ where let svc = ServiceBuilder::new() .layer(BoxService::layer()) - .map_request(move |mut request: Request| { + .map_request(move |mut request: Request| { match &conn_info { tower::util::Either::A(inner) => { request.extensions_mut().insert(inner.clone()); @@ -888,4 +879,3 @@ where future::ready(Ok(svc)) } } - diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 170000c37..389fc73a8 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -2,12 +2,11 @@ use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgen use crate::transport::{BoxFuture, Endpoint}; use http::Uri; use hyper::client::conn::http2::Builder; -use hyper_util::client::legacy::connect::{Connect as HyperConnect, Connection as HyperConnection}; +use hyper::rt; use std::{ fmt, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tower::load::Load; use tower::{ layer::Layer, @@ -29,7 +28,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { let mut settings = Builder::new(endpoint.executor) .initial_stream_window_size(endpoint.init_stream_window_size) @@ -61,9 +60,7 @@ impl Connection { .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let connector = HyperConnect::new(connector, settings); let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); - let inner = stack.layer(conn); Self { @@ -76,7 +73,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, false).ready_oneshot().await } @@ -86,7 +83,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, true) } @@ -119,4 +116,3 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } - diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index a5e0d9eb9..7a4a0eab2 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -1,12 +1,15 @@ +use std::fmt; +use std::task::{Context, Poll}; + +use http::Uri; +use hyper::rt; +use hyper_util::rt::TokioIo; +use tower_service::Service; + use super::super::BoxFuture; use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; -use http::Uri; -use std::fmt; -use std::task::{Context, Poll}; -use tower::make::MakeConnection; -use tower_service::Service; pub(crate) struct Connector { inner: C, @@ -47,8 +50,8 @@ impl Connector { impl Service for Connector where - C: MakeConnection, - C::Connection: Unpin + Send + 'static, + C: Service, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -57,7 +60,7 @@ where type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) + self.inner.poll_ready(cx).map_err(Into::into) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -69,7 +72,7 @@ where #[cfg(feature = "tls")] let is_https = uri.scheme_str() == Some("https"); - let connect = self.inner.make_connection(uri); + let connect = self.inner.call(uri); Box::pin(async move { let io = connect.await?; @@ -77,12 +80,12 @@ where #[cfg(feature = "tls")] { if let Some(tls) = tls { - if is_https { - let conn = tls.connect(io).await?; - return Ok(BoxedIo::new(conn)); + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) } else { - return Ok(BoxedIo::new(io)); - } + Ok(BoxedIo::new(io)) + }; } else if is_https { return Err(HttpsUriWithoutTlsSupport(()).into()); } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index e5e75287a..9be7fa343 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,19 +1,23 @@ -use crate::transport::server::Connected; -use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; +use hyper::rt; use std::io; use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; + +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; +use tower::util::Either; + +use crate::transport::server::Connected; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static + rt::Read + rt::Write + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -40,17 +44,17 @@ impl Connected for BoxedIo { #[derive(Copy, Clone)] pub(crate) struct NoneConnectInfo; -impl AsyncRead for BoxedIo { +impl rt::Read for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for BoxedIo { +impl rt::Write for BoxedIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -67,6 +71,10 @@ impl AsyncWrite for BoxedIo { Pin::new(&mut self.0).poll_shutdown(cx) } + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -74,10 +82,6 @@ impl AsyncWrite for BoxedIo { ) -> Poll> { Pin::new(&mut self.0).poll_write_vectored(cx, bufs) } - - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } } pub(crate) enum ServerIo { @@ -86,8 +90,6 @@ pub(crate) enum ServerIo { TlsIo(Box>), } -use tower::util::Either; - #[cfg(feature = "tls")] type ServerIoConnectInfo = Either<::ConnectInfo, as Connected>::ConnectInfo>; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index ab3d43978..89f2fc8ff 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,4 +1,4 @@ -use crate::{body::boxed, server::NamedService}; +use crate::server::NamedService; use axum::{extract::Request, response::Response}; use pin_project::pin_project; use std::{ @@ -8,7 +8,6 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tower::ServiceExt; use tower_service::Service; /// A [`Service`] router. @@ -72,7 +71,6 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); self.router = self .router .route_service(&format!("/{}/*rest", S::NAME), svc); @@ -128,9 +126,8 @@ impl Future for RoutesFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match ready!(self.project().0.poll(cx)) { - Ok(res) => Ok(res.map(boxed)).into(), + Ok(res) => Ok(res).into(), Err(err) => match err {}, } } } -