diff --git a/examples/grpc/helloworld_compressed/Cargo.toml b/examples/grpc/helloworld_compressed/Cargo.toml new file mode 100644 index 0000000000..d1723b85f5 --- /dev/null +++ b/examples/grpc/helloworld_compressed/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "example-grpc-helloworld-compressed" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[dependencies] +poem.workspace = true +poem-grpc = { workspace = true, features = [ + "gzip", + "deflate", + "brotli", + "zstd", +] } +prost.workspace = true +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } + +[build-dependencies] +poem-grpc-build.workspace = true + +[[bin]] +name = "grpc-helloworld-client" +path = "src/client.rs" diff --git a/examples/grpc/helloworld_compressed/build.rs b/examples/grpc/helloworld_compressed/build.rs new file mode 100644 index 0000000000..a388ebfa5d --- /dev/null +++ b/examples/grpc/helloworld_compressed/build.rs @@ -0,0 +1,7 @@ +use std::io::Result; + +use poem_grpc_build::compile_protos; + +fn main() -> Result<()> { + compile_protos(&["./proto/helloworld.proto"], &["./proto"]) +} diff --git a/examples/grpc/helloworld_compressed/proto/helloworld.proto b/examples/grpc/helloworld_compressed/proto/helloworld.proto new file mode 100644 index 0000000000..8de5d08ef4 --- /dev/null +++ b/examples/grpc/helloworld_compressed/proto/helloworld.proto @@ -0,0 +1,37 @@ +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} \ No newline at end of file diff --git a/examples/grpc/helloworld_compressed/src/client.rs b/examples/grpc/helloworld_compressed/src/client.rs new file mode 100644 index 0000000000..7f9abeaea9 --- /dev/null +++ b/examples/grpc/helloworld_compressed/src/client.rs @@ -0,0 +1,22 @@ +use poem_grpc::{ClientConfig, CompressionEncoding, Request}; + +poem_grpc::include_proto!("helloworld"); + +#[tokio::main] +async fn main() -> Result<(), Box> { + let mut client = GreeterClient::new( + ClientConfig::builder() + .uri("http://localhost:3000") + .build() + .unwrap(), + ); + client.set_send_compressed(CompressionEncoding::GZIP); + client.set_accept_compressed([CompressionEncoding::GZIP]); + + let request = Request::new(HelloRequest { + name: "Poem".into(), + }); + let response = client.say_hello(request).await?; + println!("RESPONSE={response:?}"); + Ok(()) +} diff --git a/examples/grpc/helloworld_compressed/src/main.rs b/examples/grpc/helloworld_compressed/src/main.rs new file mode 100644 index 0000000000..f6f4c918a4 --- /dev/null +++ b/examples/grpc/helloworld_compressed/src/main.rs @@ -0,0 +1,35 @@ +use poem::{listener::TcpListener, Server}; +use poem_grpc::{CompressionEncoding, Request, Response, RouteGrpc, Status}; + +poem_grpc::include_proto!("helloworld"); + +struct GreeterService; + +impl Greeter for GreeterService { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + let reply = HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), std::io::Error> { + let route = RouteGrpc::new().add_service( + GreeterServer::new(GreeterService) + .send_compressed(CompressionEncoding::GZIP) + .accept_compressed([ + CompressionEncoding::GZIP, + CompressionEncoding::DEFLATE, + CompressionEncoding::BROTLI, + CompressionEncoding::ZSTD, + ]), + ); + Server::new(TcpListener::bind("0.0.0.0:3000")) + .run(route) + .await +} diff --git a/poem-grpc-build/src/client.rs b/poem-grpc-build/src/client.rs index a27fea099f..89554d16e0 100644 --- a/poem-grpc-build/src/client.rs +++ b/poem-grpc-build/src/client.rs @@ -117,6 +117,16 @@ pub(crate) fn generate(config: &GrpcConfig, service: &Service, buf: &mut String) self } + /// Set the compression encoding for sending + pub fn set_send_compressed(&mut self, encoding: #crate_name::CompressionEncoding) { + self.cli.set_send_compressed(encoding); + } + + /// Set the compression encodings for accepting + pub fn set_accept_compressed(&mut self, encodings: impl ::std::convert::Into<::std::sync::Arc<[#crate_name::CompressionEncoding]>>) { + self.cli.set_accept_compressed(encodings); + } + #( #[allow(dead_code)] #methods diff --git a/poem-grpc-build/src/server.rs b/poem-grpc-build/src/server.rs index 73e5cde804..fad11d8eba 100644 --- a/poem-grpc-build/src/server.rs +++ b/poem-grpc-build/src/server.rs @@ -99,8 +99,22 @@ pub(crate) fn generate(config: &GrpcConfig, service: &Service, buf: &mut String) } #[allow(unused_imports)] - #[derive(Clone)] - pub struct #server_ident(::std::sync::Arc); + pub struct #server_ident { + inner: ::std::sync::Arc, + send_compressd: ::std::option::Option<#crate_name::CompressionEncoding>, + accept_compressed: ::std::sync::Arc<[#crate_name::CompressionEncoding]>, + } + + impl ::std::clone::Clone for #server_ident { + #[inline] + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + send_compressd: self.send_compressd, + accept_compressed: self.accept_compressed.clone(), + } + } + } impl #crate_name::Service for #server_ident { const NAME: &'static str = #service_name; @@ -109,8 +123,29 @@ pub(crate) fn generate(config: &GrpcConfig, service: &Service, buf: &mut String) #[allow(dead_code)] impl #server_ident { + /// Create a new GRPC server pub fn new(service: T) -> Self { - Self(::std::sync::Arc::new(service)) + Self { + inner: ::std::sync::Arc::new(service), + send_compressd: ::std::option::Option::None, + accept_compressed: ::std::default::Default::default(), + } + } + + /// Set the compression encoding for sending + pub fn send_compressed(self, encoding: #crate_name::CompressionEncoding) -> Self { + Self { + send_compressd: Some(encoding), + ..self + } + } + + /// Set the compression encodings for accepting + pub fn accept_compressed(self, encodings: impl ::std::convert::Into<::std::sync::Arc<[#crate_name::CompressionEncoding]>>) -> Self { + Self { + accept_compressed: encodings.into(), + ..self + } } } @@ -191,7 +226,7 @@ fn generate_unary(codec_list: &[Path], method_info: MethodInfo) -> TokenStream { crate_name, codec_list, quote! { - #crate_name::server::GrpcServer::new(codec).unary(#proxy_service_ident(svc.clone()), req).await + #crate_name::server::GrpcServer::new(codec, server.send_compressd, &server.accept_compressed).unary(#proxy_service_ident(server.inner.clone()), req).await }, ); @@ -211,9 +246,9 @@ fn generate_unary(codec_list: &[Path], method_info: MethodInfo) -> TokenStream { } route = route.at(#path, ::poem::endpoint::make({ - let svc = self.0.clone(); + let server = self.clone(); move |req| { - let svc = svc.clone(); + let server = server.clone(); async move { #call } } })); @@ -235,7 +270,7 @@ fn generate_client_streaming(codec_list: &[Path], method_info: MethodInfo) -> To crate_name, codec_list, quote! { - #crate_name::server::GrpcServer::new(codec).client_streaming(#proxy_service_ident(svc.clone()), req).await + #crate_name::server::GrpcServer::new(codec, server.send_compressd, &server.accept_compressed).client_streaming(#proxy_service_ident(server.inner.clone()), req).await }, ); @@ -255,9 +290,9 @@ fn generate_client_streaming(codec_list: &[Path], method_info: MethodInfo) -> To } route = route.at(#path, ::poem::endpoint::make({ - let svc = self.0.clone(); + let server = self.clone(); move |req| { - let svc = svc.clone(); + let server = server.clone(); async move { #call } } })); @@ -279,7 +314,7 @@ fn generate_server_streaming(codec_list: &[Path], method_info: MethodInfo) -> To crate_name, codec_list, quote! { - #crate_name::server::GrpcServer::new(codec).server_streaming(#proxy_service_ident(svc.clone()), req).await + #crate_name::server::GrpcServer::new(codec, server.send_compressd, &server.accept_compressed).server_streaming(#proxy_service_ident(server.inner.clone()), req).await }, ); @@ -299,9 +334,9 @@ fn generate_server_streaming(codec_list: &[Path], method_info: MethodInfo) -> To } route = route.at(#path, ::poem::endpoint::make({ - let svc = self.0.clone(); + let server = self.clone(); move |req| { - let svc = svc.clone(); + let server = server.clone(); async move { #call } } })); @@ -323,7 +358,7 @@ fn generate_bidirectional_streaming(codec_list: &[Path], method_info: MethodInfo crate_name, codec_list, quote! { - #crate_name::server::GrpcServer::new(codec).bidirectional_streaming(#proxy_service_ident(svc.clone()), req).await + #crate_name::server::GrpcServer::new(codec, server.send_compressd, &server.accept_compressed).bidirectional_streaming(#proxy_service_ident(server.inner.clone()), req).await }, ); @@ -343,9 +378,9 @@ fn generate_bidirectional_streaming(codec_list: &[Path], method_info: MethodInfo } route = route.at(#path, ::poem::endpoint::make({ - let svc = self.0.clone(); + let server = self.clone(); move |req| { - let svc = svc.clone(); + let server = server.clone(); async move { #call } } })); diff --git a/poem-grpc/Cargo.toml b/poem-grpc/Cargo.toml index 0efd641ae7..deaa20dd6d 100644 --- a/poem-grpc/Cargo.toml +++ b/poem-grpc/Cargo.toml @@ -16,6 +16,11 @@ categories = ["network-programming", "asynchronous"] [features] default = [] json-codec = ["serde", "serde_json"] +gzip = ["async-compression/gzip"] +deflate = ["async-compression/deflate"] +brotli = ["async-compression/brotli"] +zstd = ["async-compression/zstd"] +example_generated = [] [dependencies] poem = { workspace = true, default-features = true } @@ -23,7 +28,6 @@ poem = { workspace = true, default-features = true } futures-util.workspace = true async-stream = "0.3.3" tokio = { workspace = true, features = ["io-util", "rt", "sync", "net"] } -flate2 = "1.0.24" itoa = "1.0.2" percent-encoding = "2.1.0" bytes.workspace = true @@ -43,6 +47,8 @@ http-body-util = "0.1.0" tokio-rustls.workspace = true tower-service = "0.3.2" webpki-roots = "0.26" +async-compression = { version = "0.4.0", optional = true, features = ["tokio"] } +sync_wrapper = { version = "1.0.0", features = ["futures"] } [build-dependencies] poem-grpc-build.workspace = true diff --git a/poem-grpc/build.rs b/poem-grpc/build.rs index 81dbd1a58c..661c6a4596 100644 --- a/poem-grpc/build.rs +++ b/poem-grpc/build.rs @@ -13,5 +13,12 @@ fn main() -> Result<()> { // for test poem_grpc_build::Config::new() .internal() - .compile(&["proto/test_harness.proto"], &["proto/"]) + .compile(&["proto/test_harness.proto"], &["proto/"])?; + + // example + poem_grpc_build::Config::new() + .internal() + .compile(&["src/example_generated/routeguide.proto"], &[] as &[&str])?; + + Ok(()) } diff --git a/poem-grpc/src/client.rs b/poem-grpc/src/client.rs index e42db405cc..07a17a814b 100644 --- a/poem-grpc/src/client.rs +++ b/poem-grpc/src/client.rs @@ -18,9 +18,10 @@ use rustls::ClientConfig as TlsClientConfig; use crate::{ codec::Codec, + compression::get_incoming_encodings, connector::HttpsConnector, encoding::{create_decode_response_body, create_encode_request_body}, - Code, Metadata, Request, Response, Status, Streaming, + Code, CompressionEncoding, Metadata, Request, Response, Status, Streaming, }; pub(crate) type BoxBody = http_body_util::combinators::BoxBody; @@ -155,6 +156,8 @@ impl ClientConfigBuilder { #[derive(Clone)] pub struct GrpcClient { ep: Arc + 'static>, + send_compressd: Option, + accept_compressed: Arc<[CompressionEncoding]>, } impl GrpcClient { @@ -162,6 +165,8 @@ impl GrpcClient { pub fn new(config: ClientConfig) -> Self { Self { ep: create_client_endpoint(config), + send_compressd: None, + accept_compressed: Default::default(), } } @@ -173,9 +178,19 @@ impl GrpcClient { { Self { ep: Arc::new(ToDynEndpoint(ep.map_to_response())), + send_compressd: None, + accept_compressed: Default::default(), } } + pub fn set_send_compressed(&mut self, encoding: CompressionEncoding) { + self.send_compressd = Some(encoding); + } + + pub fn set_accept_compressed(&mut self, encodings: impl Into>) { + self.accept_compressed = encodings.into(); + } + pub fn with(mut self, middleware: M) -> Self where M: Middleware + 'static>>, @@ -198,10 +213,12 @@ impl GrpcClient { message, extensions, } = request; - let mut http_request = create_http_request::(path, metadata, extensions); + let mut http_request = + create_http_request::(path, metadata, extensions, self.send_compressd); http_request.set_body(create_encode_request_body( codec.encoder(), Streaming::new(futures_util::stream::once(async move { Ok(message) })), + self.send_compressd, )); let mut resp = self @@ -218,7 +235,9 @@ impl GrpcClient { } let body = resp.take_body(); - let mut stream = create_decode_response_body(codec.decoder(), resp.headers(), body)?; + let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?; + let mut stream = + create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?; let message = stream .try_next() @@ -243,8 +262,13 @@ impl GrpcClient { message, extensions, } = request; - let mut http_request = create_http_request::(path, metadata, extensions); - http_request.set_body(create_encode_request_body(codec.encoder(), message)); + let mut http_request = + create_http_request::(path, metadata, extensions, self.send_compressd); + http_request.set_body(create_encode_request_body( + codec.encoder(), + message, + self.send_compressd, + )); let mut resp = self .ep @@ -260,7 +284,9 @@ impl GrpcClient { } let body = resp.take_body(); - let mut stream = create_decode_response_body(codec.decoder(), resp.headers(), body)?; + let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?; + let mut stream = + create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?; let message = stream .try_next() @@ -285,10 +311,12 @@ impl GrpcClient { message, extensions, } = request; - let mut http_request = create_http_request::(path, metadata, extensions); + let mut http_request = + create_http_request::(path, metadata, extensions, self.send_compressd); http_request.set_body(create_encode_request_body( codec.encoder(), Streaming::new(futures_util::stream::once(async move { Ok(message) })), + self.send_compressd, )); let mut resp = self @@ -305,7 +333,9 @@ impl GrpcClient { } let body = resp.take_body(); - let stream = create_decode_response_body(codec.decoder(), resp.headers(), body)?; + let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?; + let stream = + create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?; Ok(Response { metadata: Metadata { @@ -326,8 +356,13 @@ impl GrpcClient { message, extensions, } = request; - let mut http_request = create_http_request::(path, metadata, extensions); - http_request.set_body(create_encode_request_body(codec.encoder(), message)); + let mut http_request = + create_http_request::(path, metadata, extensions, self.send_compressd); + http_request.set_body(create_encode_request_body( + codec.encoder(), + message, + self.send_compressd, + )); let mut resp = self .ep @@ -343,7 +378,9 @@ impl GrpcClient { } let body = resp.take_body(); - let stream = create_decode_response_body(codec.decoder(), resp.headers(), body)?; + let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?; + let stream = + create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?; Ok(Response { metadata: Metadata { @@ -358,6 +395,7 @@ fn create_http_request( path: &str, metadata: Metadata, extensions: Extensions, + send_compressd: Option, ) -> HttpRequest { let mut http_request = HttpRequest::builder() .uri_str(path) @@ -368,6 +406,12 @@ fn create_http_request( .finish(); http_request.headers_mut().extend(metadata.headers); *http_request.extensions_mut() = extensions; + if let Some(send_compressd) = send_compressd { + http_request.headers_mut().insert( + "grpc-encoding", + HeaderValue::from_str(send_compressd.as_str()).expect("BUG: invalid encoding"), + ); + } http_request } diff --git a/poem-grpc/src/compression.rs b/poem-grpc/src/compression.rs new file mode 100644 index 0000000000..71fd738260 --- /dev/null +++ b/poem-grpc/src/compression.rs @@ -0,0 +1,186 @@ +use std::{io::Result as IoResult, str::FromStr}; + +use http::HeaderMap; + +use crate::{Code, Metadata, Status}; + +/// The compression encodings. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CompressionEncoding { + /// gzip + #[cfg(feature = "gzip")] + #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] + GZIP, + /// deflate + #[cfg(feature = "deflate")] + #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))] + DEFLATE, + /// brotli + #[cfg(feature = "brotli")] + #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))] + BROTLI, + /// zstd + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + ZSTD, +} + +impl FromStr for CompressionEncoding { + type Err = (); + + #[inline] + fn from_str(s: &str) -> Result { + match s { + #[cfg(feature = "gzip")] + "gzip" => Ok(CompressionEncoding::GZIP), + #[cfg(feature = "deflate")] + "deflate" => Ok(CompressionEncoding::DEFLATE), + #[cfg(feature = "brotli")] + "br" => Ok(CompressionEncoding::BROTLI), + #[cfg(feature = "zstd")] + "zstd" => Ok(CompressionEncoding::ZSTD), + _ => Err(()), + } + } +} + +impl CompressionEncoding { + /// Returns the encoding name. + #[allow(unreachable_patterns)] + pub fn as_str(&self) -> &'static str { + match self { + #[cfg(feature = "gzip")] + CompressionEncoding::GZIP => "gzip", + #[cfg(feature = "deflate")] + CompressionEncoding::DEFLATE => "deflate", + #[cfg(feature = "brotli")] + CompressionEncoding::BROTLI => "br", + #[cfg(feature = "zstd")] + CompressionEncoding::ZSTD => "zstd", + _ => unreachable!(), + } + } + + #[allow( + unreachable_code, + unused_imports, + unused_mut, + unused_variables, + unreachable_patterns + )] + pub(crate) async fn encode(&self, data: &[u8]) -> IoResult> { + use tokio::io::AsyncReadExt; + + let mut buf = Vec::new(); + + match self { + #[cfg(feature = "gzip")] + CompressionEncoding::GZIP => { + async_compression::tokio::bufread::GzipEncoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "deflate")] + CompressionEncoding::DEFLATE => { + async_compression::tokio::bufread::DeflateEncoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "brotli")] + CompressionEncoding::BROTLI => { + async_compression::tokio::bufread::BrotliEncoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "zstd")] + CompressionEncoding::ZSTD => { + async_compression::tokio::bufread::ZstdEncoder::new(data) + .read_to_end(&mut buf) + .await?; + } + _ => unreachable!(), + } + + Ok(buf) + } + + #[allow( + unreachable_code, + unused_imports, + unused_mut, + unused_variables, + unreachable_patterns + )] + pub(crate) async fn decode(&self, data: &[u8]) -> IoResult> { + use tokio::io::AsyncReadExt; + + let mut buf = Vec::new(); + + match self { + #[cfg(feature = "gzip")] + CompressionEncoding::GZIP => { + async_compression::tokio::bufread::GzipDecoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "deflate")] + CompressionEncoding::DEFLATE => { + async_compression::tokio::bufread::DeflateDecoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "brotli")] + CompressionEncoding::BROTLI => { + async_compression::tokio::bufread::BrotliDecoder::new(data) + .read_to_end(&mut buf) + .await?; + } + #[cfg(feature = "zstd")] + CompressionEncoding::ZSTD => { + async_compression::tokio::bufread::ZstdDecoder::new(data) + .read_to_end(&mut buf) + .await?; + } + _ => unreachable!(), + } + + Ok(buf) + } +} + +fn unimplemented(accept_compressed: &[CompressionEncoding]) -> Status { + let mut md = Metadata::new(); + let mut accept_encoding = String::new(); + let mut iter = accept_compressed.iter(); + if let Some(encoding) = iter.next() { + accept_encoding.push_str(encoding.as_str()); + } + for encoding in iter { + accept_encoding.push_str(", "); + accept_encoding.push_str(encoding.as_str()); + } + md.append("grpc-accept-encoding", accept_encoding); + Status::new(Code::Unimplemented) + .with_metadata(md) + .with_message("unsupported encoding") +} + +pub(crate) fn get_incoming_encodings( + headers: &HeaderMap, + accept_compressed: &[CompressionEncoding], +) -> Result, Status> { + let Some(value) = headers.get("grpc-encoding") else { + return Ok(None); + }; + let Some(encoding) = value + .to_str() + .ok() + .and_then(|value| value.parse::().ok()) + else { + return Err(unimplemented(accept_compressed)); + }; + if !accept_compressed.contains(&encoding) { + return Err(unimplemented(accept_compressed)); + } + Ok(Some(encoding)) +} diff --git a/poem-grpc/src/encoding.rs b/poem-grpc/src/encoding.rs index 4a19411047..6a4e79c9bb 100644 --- a/poem-grpc/src/encoding.rs +++ b/poem-grpc/src/encoding.rs @@ -1,42 +1,60 @@ -use std::io::{Error as IoError, Result as IoResult}; +use std::io::Result as IoResult; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use flate2::read::GzDecoder; use futures_util::StreamExt; use http_body_util::{BodyExt, StreamBody}; use hyper::{body::Frame, HeaderMap}; use poem::Body; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; +use sync_wrapper::SyncStream; use crate::{ client::BoxBody, codec::{Decoder, Encoder}, - Code, Status, Streaming, + Code, CompressionEncoding, Status, Streaming, }; -fn encode_data_frame( +async fn encode_data_frame( encoder: &mut T, buf: &mut BytesMut, message: T::Item, + compression: Option, ) -> IoResult { - buf.put_slice(&[0, 0, 0, 0, 0]); - encoder.encode(message, buf)?; + buf.put_slice(&[compression.is_some() as u8, 0, 0, 0, 0]); + + if let Some(compression) = compression { + let mut data = BytesMut::new(); + encoder.encode(message, &mut data)?; + let data = compression.encode(&data).await?; + buf.extend(data); + } else { + encoder.encode(message, buf)?; + } + let msg_len = (buf.len() - 5) as u32; buf.as_mut()[1..5].copy_from_slice(&msg_len.to_be_bytes()); Ok(buf.split().freeze()) } -#[derive(Default)] struct DataFrameDecoder { buf: BytesMut, + compression: Option, } impl DataFrameDecoder { + #[inline] + fn new(compression: Option) -> Self { + Self { + buf: BytesMut::new(), + compression, + } + } + + #[inline] fn put_slice(&mut self, data: impl AsRef<[u8]>) { self.buf.extend_from_slice(data.as_ref()); } + #[inline] fn check_incomplete(&self) -> Result<(), Status> { if !self.buf.is_empty() { return Err(Status::new(Code::Internal).with_message("incomplete request")); @@ -44,7 +62,7 @@ impl DataFrameDecoder { Ok(()) } - fn next(&mut self) -> Result, Status> { + async fn next(&mut self) -> Result, Status> { if self.buf.len() < 5 { return Ok(None); } @@ -62,11 +80,15 @@ impl DataFrameDecoder { let data = self.buf.split_to(len).freeze(); if compressed { - let mut decoder = GzDecoder::new(&*data); - let raw_data = BytesMut::new(); - let mut writer = raw_data.writer(); - std::io::copy(&mut decoder, &mut writer).map_err(Status::from_std_error)?; - Ok(Some(writer.into_inner().freeze())) + let compression = self.compression.ok_or_else(|| { + Status::new(Code::Unimplemented) + .with_message(format!("unsupported compressed flag: {compressed}")) + })?; + let data = compression + .decode(&data) + .await + .map_err(|err| Status::new(Code::Internal).with_message(err.to_string()))?; + Ok(Some(data.into())) } else { Ok(Some(data)) } @@ -79,18 +101,19 @@ impl DataFrameDecoder { pub(crate) fn create_decode_request_body( mut decoder: T, body: Body, + compression: Option, ) -> Streaming { let mut body: BoxBody = body.into(); Streaming::new(async_stream::try_stream! { - let mut frame_decoder = DataFrameDecoder::default(); + let mut frame_decoder = DataFrameDecoder::new(compression); loop { match body.frame().await.transpose().map_err(Status::from_std_error)? { Some(frame) => { if let Ok(data) = frame.into_data() { frame_decoder.put_slice(data); - while let Some(data) = frame_decoder.next()? { + while let Some(data) = frame_decoder.next().await? { let message = decoder.decode(&data).map_err(Status::from_std_error)?; yield message; } @@ -108,67 +131,53 @@ pub(crate) fn create_decode_request_body( pub(crate) fn create_encode_response_body( mut encoder: T, mut stream: Streaming, + compression: Option, ) -> Body { - let (tx, rx) = mpsc::channel(16); - - tokio::spawn(async move { + let stream = async_stream::try_stream! { let mut buf = BytesMut::new(); while let Some(item) = stream.next().await { match item { Ok(message) => { - if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) { - if tx.send(Frame::data(data)).await.is_err() { - return; - } + if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message, compression).await { + yield Frame::data(data); } } Err(status) => { - _ = tx.send(Frame::trailers(status.to_headers())).await; - return; + yield Frame::trailers(status.to_headers()); } } } - _ = tx - .send(Frame::trailers(Status::new(Code::Ok).to_headers())) - .await; - }); + yield Frame::trailers(Status::new(Code::Ok).to_headers()); + }; - BodyExt::boxed(StreamBody::new( - ReceiverStream::new(rx).map(Ok::<_, IoError>), - )) - .into() + BodyExt::boxed(StreamBody::new(SyncStream::new(stream))).into() } pub(crate) fn create_encode_request_body( mut encoder: T, mut stream: Streaming, + compression: Option, ) -> Body { - let (tx, rx) = mpsc::channel(16); - - tokio::spawn(async move { + let stream = async_stream::try_stream! { let mut buf = BytesMut::new(); while let Some(Ok(message)) = stream.next().await { - if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) { - if tx.send(Frame::data(data)).await.is_err() { - return; - } + if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message, compression).await { + yield Frame::data(data); } } - }); + }; - BodyExt::boxed(StreamBody::new( - ReceiverStream::new(rx).map(Ok::<_, IoError>), - )) - .into() + BodyExt::boxed(StreamBody::new(SyncStream::new(stream))).into() } pub(crate) fn create_decode_response_body( mut decoder: T, headers: &HeaderMap, body: Body, + compression: Option, ) -> Result, Status> { // check is trailers-only if let Some(status) = Status::from_headers(headers)? { @@ -182,14 +191,14 @@ pub(crate) fn create_decode_response_body( let mut body: BoxBody = body.into(); Ok(Streaming::new(async_stream::try_stream! { - let mut frame_decoder = DataFrameDecoder::default(); + let mut frame_decoder = DataFrameDecoder::new(compression); let mut status = None; while let Some(frame) = body.frame().await.transpose().map_err(Status::from_std_error)? { if frame.is_data() { let data = frame.into_data().unwrap(); frame_decoder.put_slice(data); - while let Some(data) = frame_decoder.next()? { + while let Some(data) = frame_decoder.next().await? { let message = decoder.decode(&data).map_err(Status::from_std_error)?; yield message; } @@ -254,7 +263,7 @@ mod tests { let mut codec = ProstCodec::::default(); let mut streaming = - create_decode_response_body(codec.decoder(), &HeaderMap::default(), body) + create_decode_response_body(codec.decoder(), &HeaderMap::default(), body, None) .expect("streaming"); let stream_msg = streaming diff --git a/poem-grpc/src/example_generated/mod.rs b/poem-grpc/src/example_generated/mod.rs new file mode 100644 index 0000000000..1c00269490 --- /dev/null +++ b/poem-grpc/src/example_generated/mod.rs @@ -0,0 +1,6 @@ +//! This module shows an example of code generated by the macro. **IT MUST NOT BE USED OUTSIDE THIS +//! CRATE**. + +#![allow(missing_docs)] + +include!(concat!(env!("OUT_DIR"), "/routeguide.rs")); diff --git a/poem-grpc/src/example_generated/routeguide.proto b/poem-grpc/src/example_generated/routeguide.proto new file mode 100644 index 0000000000..0f54eeaace --- /dev/null +++ b/poem-grpc/src/example_generated/routeguide.proto @@ -0,0 +1,110 @@ +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.routeguide"; +option java_outer_classname = "RouteGuideProto"; + +package routeguide; + +// Interface exported by the server. +service RouteGuide { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // A feature with an empty name is returned if there's no feature at the given + // position. + rpc GetFeature(Point) returns (Feature) {} + + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + rpc ListFeatures(Rectangle) returns (stream Feature) {} + + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + rpc RecordRoute(stream Point) returns (RouteSummary) {} + + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} +} + +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +message Point { + int32 latitude = 1; + int32 longitude = 2; +} + +// A latitude-longitude rectangle, represented as two diagonally opposite +// points "lo" and "hi". +message Rectangle { + // One corner of the rectangle. + Point lo = 1; + + // The other corner of the rectangle. + Point hi = 2; +} + +// A feature names something at a given point. +// +// If a feature could not be named, the name is empty. +message Feature { + // The name of the feature. + string name = 1; + + // The point where the feature is detected. + Point location = 2; +} + +// A RouteNote is a message sent while at a given point. +message RouteNote { + // The location from which the message is sent. + Point location = 1; + + // The message to be sent. + string message = 2; +} + +// A RouteSummary is received in response to a RecordRoute rpc. +// +// It contains the number of individual points received, the number of +// detected features, and the total distance covered as the cumulative sum of +// the distance between each point. +message RouteSummary { + // The number of points received. + int32 point_count = 1; + + // The number of known features passed while traversing the route. + int32 feature_count = 2; + + // The distance covered in metres. + int32 distance = 3; + + // The duration of the traversal in seconds. + int32 elapsed_time = 4; +} \ No newline at end of file diff --git a/poem-grpc/src/lib.rs b/poem-grpc/src/lib.rs index 67cb9e13bf..c3fbbbaecb 100644 --- a/poem-grpc/src/lib.rs +++ b/poem-grpc/src/lib.rs @@ -20,8 +20,11 @@ pub mod service; pub mod codec; pub mod metadata; +mod compression; mod connector; mod encoding; +#[cfg(feature = "example_generated")] +pub mod example_generated; mod health; mod reflection; mod request; @@ -33,6 +36,7 @@ mod streaming; mod test_harness; pub use client::{ClientBuilderError, ClientConfig, ClientConfigBuilder}; +pub use compression::CompressionEncoding; pub use health::{health_service, HealthReporter, ServingStatus}; pub use metadata::Metadata; pub use reflection::Reflection; diff --git a/poem-grpc/src/server.rs b/poem-grpc/src/server.rs index d2ead0a91d..642c31ac4e 100644 --- a/poem-grpc/src/server.rs +++ b/poem-grpc/src/server.rs @@ -1,32 +1,54 @@ use futures_util::StreamExt; -use poem::{Request, Response}; +use http::HeaderValue; +use poem::{Body, Request, Response}; use crate::{ codec::Codec, + compression::get_incoming_encodings, encoding::{create_decode_request_body, create_encode_response_body}, service::{ BidirectionalStreamingService, ClientStreamingService, ServerStreamingService, UnaryService, }, - Code, Metadata, Request as GrpcRequest, Response as GrpcResponse, Status, Streaming, + Code, CompressionEncoding, Metadata, Request as GrpcRequest, Response as GrpcResponse, Status, + Streaming, }; #[doc(hidden)] -pub struct GrpcServer { +pub struct GrpcServer<'a, T> { codec: T, + send_compressd: Option, + accept_compressed: &'a [CompressionEncoding], } -impl GrpcServer { +impl<'a, T: Codec> GrpcServer<'a, T> { #[inline] - pub fn new(codec: T) -> Self { - Self { codec } + pub fn new( + codec: T, + send_compressd: Option, + accept_compressed: &'a [CompressionEncoding], + ) -> Self { + Self { + codec, + send_compressd, + accept_compressed, + } } - pub async fn unary(&mut self, service: S, request: Request) -> Response + pub async fn unary(mut self, service: S, request: Request) -> Response where S: UnaryService, { let (parts, body) = request.into_parts(); - let mut stream = create_decode_request_body(self.codec.decoder(), body); + let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); + let incoming_encoding = + match get_incoming_encodings(&parts.headers, &self.accept_compressed) { + Ok(incoming_encoding) => incoming_encoding, + Err(status) => { + resp.headers_mut().extend(status.to_headers()); + return resp; + } + }; + let mut stream = create_decode_request_body(self.codec.decoder(), body, incoming_encoding); let res = match stream.next().await { Some(Ok(message)) => { @@ -44,32 +66,37 @@ impl GrpcServer { None => Err(Status::new(Code::Internal).with_message("missing request message")), }; - let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); - match res { Ok(grpc_resp) => { let GrpcResponse { metadata, message } = grpc_resp; let body = create_encode_response_body( self.codec.encoder(), Streaming::new(futures_util::stream::once(async move { Ok(message) })), + self.send_compressd, ); - resp.headers_mut().extend(metadata.headers); - resp.set_body(body); - } - Err(status) => { - resp.headers_mut().extend(status.to_headers()); + update_http_response(&mut resp, metadata, body, self.send_compressd); } + Err(status) => resp.headers_mut().extend(status.to_headers()), } resp } - pub async fn client_streaming(&mut self, service: S, request: Request) -> Response + pub async fn client_streaming(mut self, service: S, request: Request) -> Response where S: ClientStreamingService, { let (parts, body) = request.into_parts(); - let stream = create_decode_request_body(self.codec.decoder(), body); + let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); + let incoming_encoding = + match get_incoming_encodings(&parts.headers, &self.accept_compressed) { + Ok(incoming_encoding) => incoming_encoding, + Err(status) => { + resp.headers_mut().extend(status.to_headers()); + return resp; + } + }; + let stream = create_decode_request_body(self.codec.decoder(), body, incoming_encoding); let res = service .call(GrpcRequest { @@ -81,17 +108,15 @@ impl GrpcServer { }) .await; - let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); - match res { Ok(grpc_resp) => { let GrpcResponse { metadata, message } = grpc_resp; let body = create_encode_response_body( self.codec.encoder(), Streaming::new(futures_util::stream::once(async move { Ok(message) })), + self.send_compressd, ); - resp.headers_mut().extend(metadata.headers); - resp.set_body(body); + update_http_response(&mut resp, metadata, body, self.send_compressd); } Err(status) => { resp.headers_mut().extend(status.to_headers()); @@ -101,12 +126,21 @@ impl GrpcServer { resp } - pub async fn server_streaming(&mut self, service: S, request: Request) -> Response + pub async fn server_streaming(mut self, service: S, request: Request) -> Response where S: ServerStreamingService, { let (parts, body) = request.into_parts(); - let mut stream = create_decode_request_body(self.codec.decoder(), body); + let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); + let incoming_encoding = + match get_incoming_encodings(&parts.headers, &self.accept_compressed) { + Ok(incoming_encoding) => incoming_encoding, + Err(status) => { + resp.headers_mut().extend(status.to_headers()); + return resp; + } + }; + let mut stream = create_decode_request_body(self.codec.decoder(), body, incoming_encoding); let res = match stream.next().await { Some(Ok(message)) => { @@ -124,14 +158,12 @@ impl GrpcServer { None => Err(Status::new(Code::Internal).with_message("missing request message")), }; - let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); - match res { Ok(grpc_resp) => { let GrpcResponse { metadata, message } = grpc_resp; - let body = create_encode_response_body(self.codec.encoder(), message); - resp.headers_mut().extend(metadata.headers); - resp.set_body(body); + let body = + create_encode_response_body(self.codec.encoder(), message, self.send_compressd); + update_http_response(&mut resp, metadata, body, self.send_compressd); } Err(status) => { resp.headers_mut().extend(status.to_headers()); @@ -141,12 +173,21 @@ impl GrpcServer { resp } - pub async fn bidirectional_streaming(&mut self, service: S, request: Request) -> Response + pub async fn bidirectional_streaming(mut self, service: S, request: Request) -> Response where S: BidirectionalStreamingService, { let (parts, body) = request.into_parts(); - let stream = create_decode_request_body(self.codec.decoder(), body); + let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); + let incoming_encoding = + match get_incoming_encodings(&parts.headers, &self.accept_compressed) { + Ok(incoming_encoding) => incoming_encoding, + Err(status) => { + resp.headers_mut().extend(status.to_headers()); + return resp; + } + }; + let stream = create_decode_request_body(self.codec.decoder(), body, incoming_encoding); let res = service .call(GrpcRequest { @@ -158,14 +199,12 @@ impl GrpcServer { }) .await; - let mut resp = Response::default().set_content_type(T::CONTENT_TYPES[0]); - match res { Ok(grpc_resp) => { let GrpcResponse { metadata, message } = grpc_resp; - let body = create_encode_response_body(self.codec.encoder(), message); - resp.headers_mut().extend(metadata.headers); - resp.set_body(body); + let body = + create_encode_response_body(self.codec.encoder(), message, self.send_compressd); + update_http_response(&mut resp, metadata, body, self.send_compressd); } Err(status) => { resp.headers_mut().extend(status.to_headers()); @@ -175,3 +214,19 @@ impl GrpcServer { resp } } + +fn update_http_response( + resp: &mut Response, + metadata: Metadata, + body: Body, + send_compressd: Option, +) { + resp.headers_mut().extend(metadata.headers); + if let Some(send_compressd) = send_compressd { + resp.headers_mut().insert( + "grpc-encoding", + HeaderValue::from_str(send_compressd.as_str()).expect("BUG: invalid encoding"), + ); + } + resp.set_body(body); +}