diff --git a/Cargo.toml b/Cargo.toml index 85bd1bba..59c9bd54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,9 +45,10 @@ futures = "0.3.28" futures-core = "0.3.4" futures-util = "0.3.28" generic-json = "^0.7" -http = "0.2.9" -http-body = "0.4.5" -hyper = "0.14.27" +http = "1.0.0" +http-body = "1.0.0" +http-body-util = "0.1.0" +hyper = "1.0.1" image = "0.24.3" iref = "^3.1.2" lazy_static = "1.4.0" @@ -56,8 +57,8 @@ paho-mqtt = "0.12" parking_lot = "0.12.1" prost = "0.12" prost-types = "0.12" -regex = " 1.9.3" -sdl2 = "0.35.2" +regex = " 1.10.2" +sdl2 = "0.36.0" serde = "1.0.160" serde_derive = "1.0.163" serde_json = "^1.0" @@ -68,7 +69,7 @@ tokio-stream = "0.1.14" tonic = "0.10.0" tonic-build = "0.10.0" tower = "0.4.13" -tower-http = "0.4.3" +tower-http = "0.5.0" url = "2.3.1" uuid = "1.2.2" yaml-rust = "0.4" diff --git a/core/common/Cargo.toml b/core/common/Cargo.toml index fa15c191..94e98fe1 100644 --- a/core/common/Cargo.toml +++ b/core/common/Cargo.toml @@ -18,6 +18,7 @@ futures-core = { workspace = true } futures-util = { workspace = true } http = { workspace = true } http-body = { workspace = true } +http-body-util = { workspace = true } hyper = { workspace = true } log = { workspace = true } parking_lot = { workspace = true } diff --git a/core/common/src/grpc_interceptor.rs b/core/common/src/grpc_interceptor.rs index ecf5906d..0eef0ccc 100644 --- a/core/common/src/grpc_interceptor.rs +++ b/core/common/src/grpc_interceptor.rs @@ -14,6 +14,8 @@ use std::error::Error; use std::pin::Pin; use tower::{Layer, Service}; +use crate::utils; + // This module provides the gRPC Interceptor construct. It can be used to // intercept gRPC calls, and examine/modify their requests and responses. @@ -159,7 +161,7 @@ where if is_applicable && interceptor.must_handle_request() { let (parts, body) = request.into_parts(); let mut body_bytes: Bytes = - match futures::executor::block_on(hyper::body::to_bytes(body)) { + match futures::executor::block_on(utils::to_bytes(&mut body, None)) { Ok(bytes) => bytes, Err(err) => { return Box::pin(async move { @@ -191,7 +193,7 @@ where if is_applicable && interceptor.must_handle_response() { let (parts, body) = response.into_parts(); - let mut body_bytes = match hyper::body::to_bytes(body).await { + let mut body_bytes = match utils::to_bytes(&mut body, None).await { Ok(bytes) => bytes, Err(err) => { return Err(Box::new(err) as Box) diff --git a/core/common/src/utils.rs b/core/common/src/utils.rs index ab732d06..fe03796c 100644 --- a/core/common/src/utils.rs +++ b/core/common/src/utils.rs @@ -4,10 +4,13 @@ #![allow(unused_imports)] +use bytes::{Bytes, BytesMut}; use config::{Config, ConfigError, File, FileFormat}; use core_protobuf_data_access::chariott::service_discovery::core::v1::{ service_registry_client::ServiceRegistryClient, DiscoverRequest, }; +use http_body::Body; +use http_body_util::{BodyExt, combinators::UnsyncBoxBody}; use log::{debug, info}; use serde_derive::Deserialize; use std::env; @@ -218,6 +221,33 @@ pub fn get_uri(uri: &str) -> Result { Ok(uri.to_string()) } +/// Converts an HTTP body to bytes, propagating errors from the body. +/// +/// # Arguments +/// - `body`: the body to read +/// - `max_length`: an optional maximum number of bytes to read. Body frames will be read until this value is exceeded. Setting this value can help avoid DoS attacks. +// pub async fn to_bytes<'a, T, E>(body: &mut T, max_length: Option) -> Result +// where +// T: Body + Unpin +pub async fn to_bytes(body: &mut UnsyncBoxBody, max_length: Option) -> Result +{ + let mut buf = BytesMut::new(); + while let Some(next) = body.frame().await { + let frame = next?; + + // Only capture DATA frames and skip others, such as trailer frames + if let Some(chunk) = frame.data_ref() { + buf.extend_from_slice(&chunk[..]); + } + + if buf.len() >= max_length.unwrap_or(usize::MAX) { + return Ok(buf.freeze()); + } + } + + Ok(buf.freeze()) +} + #[cfg(test)] mod tests { use super::*;