From 80ad2cfe1848139254c09cd979b9d646629053a7 Mon Sep 17 00:00:00 2001 From: laststylebender Date: Tue, 3 Dec 2024 13:38:17 +0530 Subject: [PATCH] - code clean up --- src/core/http/data_loader.rs | 90 ++----- src/core/http/mod.rs | 1 + .../http/transformations/body_batching.rs | 248 ++++++++++++++++++ src/core/http/transformations/mod.rs | 5 + .../http/transformations/query_batching.rs | 200 ++++++++++++++ 5 files changed, 473 insertions(+), 71 deletions(-) create mode 100644 src/core/http/transformations/body_batching.rs create mode 100644 src/core/http/transformations/mod.rs create mode 100644 src/core/http/transformations/query_batching.rs diff --git a/src/core/http/data_loader.rs b/src/core/http/data_loader.rs index acbbd68e59..845039508e 100644 --- a/src/core/http/data_loader.rs +++ b/src/core/http/data_loader.rs @@ -6,14 +6,17 @@ use std::time::Duration; use async_graphql::async_trait; use async_graphql::futures_util::future::join_all; use async_graphql_value::ConstValue; -use reqwest::Request; +use tailcall_valid::Validator; +use super::transformations::{BodyBatching, QueryBatching}; use crate::core::config::group_by::GroupBy; use crate::core::config::Batch; use crate::core::data_loader::{DataLoader, Loader}; use crate::core::http::{DataLoaderRequest, Response}; use crate::core::json::JsonLike; use crate::core::runtime::TargetRuntime; +use crate::core::transform::TransformerOps; +use crate::core::Transform; fn get_body_value_single(body_value: &HashMap>, id: &str) -> ConstValue { body_value @@ -58,58 +61,6 @@ fn get_key<'a, T: JsonLike<'a> + Display>(value: &'a T, path: &[String]) -> anyh .ok_or_else(|| anyhow::anyhow!("Unable to find key {} in body", path.join("."))) } -/// This function is used to batch the body of the requests. -/// working of this function is as follows: -/// 1. It takes the list of requests and extracts the body from each request. -/// 2. It then clubs all the extracted bodies into list format. like [body1, -/// body2, body3] -/// 3. It does this all manually to avoid extra serialization cost. -fn batch_request_body(mut base_request: Request, requests: &[DataLoaderRequest]) -> Request { - let mut request_bodies = Vec::with_capacity(requests.len()); - - if base_request.method() == reqwest::Method::GET { - // in case of GET method do nothing and return the base request. - return base_request; - } - - for req in requests { - if let Some(body) = req.body().and_then(|b| b.as_bytes()) { - request_bodies.push(body); - } - } - - if !request_bodies.is_empty() { - if cfg!(feature = "integration_test") || cfg!(test) { - // sort the body to make it consistent for testing env. - request_bodies.sort(); - } - - // construct serialization manually. - let merged_body = request_bodies.iter().fold( - Vec::with_capacity( - request_bodies.iter().map(|i| i.len()).sum::() + request_bodies.len(), - ), - |mut acc, item| { - if !acc.is_empty() { - // add ',' to separate the body from each other. - acc.extend_from_slice(b","); - } - acc.extend_from_slice(item); - acc - }, - ); - - // add list brackets to the serialized body. - let mut serialized_body = Vec::with_capacity(merged_body.len() + 2); - serialized_body.extend_from_slice(b"["); - serialized_body.extend_from_slice(&merged_body); - serialized_body.extend_from_slice(b"]"); - base_request.body_mut().replace(serialized_body.into()); - } - - base_request -} - #[async_trait::async_trait] impl Loader for HttpDataLoader { type Value = Response; @@ -128,24 +79,21 @@ impl Loader for HttpDataLoader { } if let Some(base_dl_request) = dl_requests.first().as_mut() { - // Create base request - let mut base_request = - batch_request_body(base_dl_request.to_request(), &dl_requests); - - // Merge query params in the request - for key in dl_requests.iter().skip(1) { - let request = key.to_request(); - let url = request.url(); - let pairs: Vec<_> = url - .query_pairs() - .filter(|(key, _)| group_by.key().eq(&key.to_string())) - .collect(); - if !pairs.is_empty() { - // if pair's are empty then don't extend the query params else it ends - // up appending '?' to the url. - base_request.url_mut().query_pairs_mut().extend_pairs(pairs); - } - } + let base_request = if base_dl_request.method() == http::Method::GET { + QueryBatching::new( + &dl_requests.iter().skip(1).collect::>(), + Some(group_by.key()), + ) + .transform(base_dl_request.to_request()) + .to_result() + .map_err(|e| anyhow::anyhow!(e))? + } else { + QueryBatching::new(&dl_requests.iter().skip(1).collect::>(), None) + .pipe(BodyBatching::new(&dl_requests.iter().collect::>())) + .transform(base_dl_request.to_request()) + .to_result() + .map_err(|e| anyhow::anyhow!(e))? + }; // Dispatch request let res = self diff --git a/src/core/http/mod.rs b/src/core/http/mod.rs index 69967ed098..62c50094ce 100644 --- a/src/core/http/mod.rs +++ b/src/core/http/mod.rs @@ -20,6 +20,7 @@ mod request_template; mod response; pub mod showcase; mod telemetry; +mod transformations; pub static TAILCALL_HTTPS_ORIGIN: HeaderValue = HeaderValue::from_static("https://tailcall.run"); pub static TAILCALL_HTTP_ORIGIN: HeaderValue = HeaderValue::from_static("http://tailcall.run"); diff --git a/src/core/http/transformations/body_batching.rs b/src/core/http/transformations/body_batching.rs new file mode 100644 index 0000000000..62b361185a --- /dev/null +++ b/src/core/http/transformations/body_batching.rs @@ -0,0 +1,248 @@ +use std::convert::Infallible; + +use reqwest::Request; +use tailcall_valid::Valid; + +use crate::core::http::DataLoaderRequest; +use crate::core::Transform; + +pub struct BodyBatching<'a> { + dl_requests: &'a [&'a DataLoaderRequest], +} + +impl<'a> BodyBatching<'a> { + pub fn new(dl_requests: &'a [&'a DataLoaderRequest]) -> Self { + BodyBatching { dl_requests } + } +} + +impl Transform for BodyBatching<'_> { + type Value = Request; + type Error = Infallible; + + // This function is used to batch the body of the requests. + // working of this function is as follows: + // 1. It takes the list of requests and extracts the body from each request. + // 2. It then clubs all the extracted bodies into list format. like [body1, + // body2, body3] + // 3. It does this all manually to avoid extra serialization cost. + fn transform(&self, mut base_request: Self::Value) -> Valid { + let mut request_bodies = Vec::with_capacity(self.dl_requests.len()); + + for req in self.dl_requests { + if let Some(body) = req.body().and_then(|b| b.as_bytes()) { + request_bodies.push(body); + } + } + + if !request_bodies.is_empty() { + if cfg!(feature = "integration_test") || cfg!(test) { + // sort the body to make it consistent for testing env. + request_bodies.sort(); + } + + // construct serialization manually. + let merged_body = request_bodies.iter().fold( + Vec::with_capacity( + request_bodies.iter().map(|i| i.len()).sum::() + request_bodies.len(), + ), + |mut acc, item| { + if !acc.is_empty() { + // add ',' to separate the body from each other. + acc.extend_from_slice(b","); + } + acc.extend_from_slice(item); + acc + }, + ); + + // add list brackets to the serialized body. + let mut serialized_body = Vec::with_capacity(merged_body.len() + 2); + serialized_body.extend_from_slice(b"["); + serialized_body.extend_from_slice(&merged_body); + serialized_body.extend_from_slice(b"]"); + base_request.body_mut().replace(serialized_body.into()); + } + + Valid::succeed(base_request) + } +} + +#[cfg(test)] +mod tests { + use http::Method; + use reqwest::Request; + use serde_json::json; + use tailcall_valid::Validator; + + use super::*; + use crate::core::http::DataLoaderRequest; + + fn create_request(body: Option) -> DataLoaderRequest { + let mut request = create_base_request(); + if let Some(body) = body { + let bytes_body = serde_json::to_vec(&body).unwrap(); + request.body_mut().replace(reqwest::Body::from(bytes_body)); + } + + DataLoaderRequest::new(request, Default::default()) + } + + fn create_base_request() -> Request { + Request::new(Method::POST, "http://example.com".parse().unwrap()) + } + + #[test] + fn test_empty_requests() { + let requests: Vec<&DataLoaderRequest> = vec![]; + let base_request = create_base_request(); + + let result = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + assert!(result.body().is_none()); + } + + #[test] + fn test_single_request() { + let req = create_request(Some(json!({"id": 1}))); + let requests = vec![&req]; + let base_request = create_base_request(); + + let request = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + let bytes = request + .body() + .and_then(|b| b.as_bytes()) + .unwrap_or_default(); + let body_str = String::from_utf8(bytes.to_vec()).unwrap(); + assert_eq!(body_str, r#"[{"id":1}]"#); + } + + #[test] + fn test_multiple_requests() { + let req1 = create_request(Some(json!({"id": 1}))); + let req2 = create_request(Some(json!({"id": 2}))); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + let body = result.body().and_then(|b| b.as_bytes()).unwrap(); + let body_str = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body_str, r#"[{"id":1},{"id":2}]"#); + } + + #[test] + fn test_requests_with_empty_bodies() { + let req1 = create_request(Some(json!({"id": 1}))); + let req2 = create_request(None); + let req3 = create_request(Some(json!({"id": 3}))); + let requests = vec![&req1, &req2, &req3]; + let base_request = create_base_request(); + + let result = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + let body_bytes = result + .body() + .and_then(|b| b.as_bytes()) + .expect("Body should be present"); + let parsed: Vec = serde_json::from_slice(body_bytes).unwrap(); + + assert_eq!(parsed.len(), 2); + assert_eq!(parsed[0]["id"], 1); + assert_eq!(parsed[1]["id"], 3); + } + + #[test] + #[cfg(test)] + fn test_body_sorting_in_test_env() { + let req1 = create_request(Some(json!({ + "id": 2, + "value": "second" + }))); + let req2 = create_request(Some(json!({ + "id": 1, + "value": "first" + }))); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + let body_bytes = result + .body() + .and_then(|b| b.as_bytes()) + .expect("Body should be present"); + let parsed: Vec = serde_json::from_slice(body_bytes).unwrap(); + + // Verify sorting by comparing the serialized form + assert_eq!(parsed.len(), 2); + assert_eq!(parsed[0]["id"], 1); + assert_eq!(parsed[0]["value"], "first"); + assert_eq!(parsed[1]["id"], 2); + assert_eq!(parsed[1]["value"], "second"); + } + + #[test] + fn test_complex_json_bodies() { + let req1 = create_request(Some(json!({ + "id": 1, + "nested": { + "array": [1, 2, 3], + "object": {"key": "value"} + }, + "tags": ["a", "b", "c"] + }))); + let req2 = create_request(Some(json!({ + "id": 2, + "nested": { + "array": [4, 5, 6], + "object": {"key": "another"} + }, + "tags": ["x", "y", "z"] + }))); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = BodyBatching::new(&requests) + .transform(base_request) + .to_result() + .unwrap(); + + let body_bytes = result + .body() + .and_then(|b| b.as_bytes()) + .expect("Body should be present"); + let parsed: Vec = serde_json::from_slice(body_bytes).unwrap(); + + // Verify structure and content of both objects + assert_eq!(parsed.len(), 2); + + // First object + assert_eq!(parsed[0]["id"], 1); + assert_eq!(parsed[0]["nested"]["array"], json!([1, 2, 3])); + assert_eq!(parsed[0]["nested"]["object"]["key"], "value"); + assert_eq!(parsed[0]["tags"], json!(["a", "b", "c"])); + + // Second object + assert_eq!(parsed[1]["id"], 2); + assert_eq!(parsed[1]["nested"]["array"], json!([4, 5, 6])); + assert_eq!(parsed[1]["nested"]["object"]["key"], "another"); + assert_eq!(parsed[1]["tags"], json!(["x", "y", "z"])); + } +} diff --git a/src/core/http/transformations/mod.rs b/src/core/http/transformations/mod.rs new file mode 100644 index 0000000000..b6ab71810c --- /dev/null +++ b/src/core/http/transformations/mod.rs @@ -0,0 +1,5 @@ +mod body_batching; +mod query_batching; + +pub use body_batching::BodyBatching; +pub use query_batching::QueryBatching; diff --git a/src/core/http/transformations/query_batching.rs b/src/core/http/transformations/query_batching.rs new file mode 100644 index 0000000000..1612608388 --- /dev/null +++ b/src/core/http/transformations/query_batching.rs @@ -0,0 +1,200 @@ +use std::convert::Infallible; + +use reqwest::Request; +use tailcall_valid::Valid; + +use crate::core::http::DataLoaderRequest; +use crate::core::Transform; + +pub struct QueryBatching<'a> { + dl_requests: &'a [&'a DataLoaderRequest], + group_by: Option<&'a str>, +} + +impl<'a> QueryBatching<'a> { + pub fn new(dl_requests: &'a [&'a DataLoaderRequest], group_by: Option<&'a str>) -> Self { + QueryBatching { dl_requests, group_by } + } +} + +impl Transform for QueryBatching<'_> { + type Value = Request; + type Error = Infallible; + fn transform(&self, mut base_request: Self::Value) -> Valid { + // Merge query params in the request + for key in self.dl_requests.iter() { + let request = key.to_request(); + let url = request.url(); + let pairs: Vec<_> = if let Some(group_by_key) = self.group_by { + url.query_pairs() + .filter(|(key, _)| group_by_key.eq(&key.to_string())) + .collect() + } else { + url.query_pairs().collect() + }; + + if !pairs.is_empty() { + // if pair's are empty then don't extend the query params else it ends + // up appending '?' to the url. + base_request.url_mut().query_pairs_mut().extend_pairs(pairs); + } + } + Valid::succeed(base_request) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use http::Method; + use reqwest::Url; + use tailcall_valid::Validator; + + use super::*; + + fn create_base_request() -> Request { + Request::new(Method::GET, "http://example.com".parse().unwrap()) + } + + fn create_request_with_params(params: &[(&str, &str)]) -> DataLoaderRequest { + let mut url = Url::parse("http://example.com").unwrap(); + { + let mut query_pairs = url.query_pairs_mut(); + for (key, value) in params { + query_pairs.append_pair(key, value); + } + } + let request = Request::new(Method::GET, url); + DataLoaderRequest::new(request, Default::default()) + } + + fn get_query_params(request: &Request) -> HashMap { + request + .url() + .query_pairs() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } + + #[test] + fn test_empty_requests() { + let requests: Vec<&DataLoaderRequest> = vec![]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, None) + .transform(base_request) + .to_result() + .unwrap(); + + assert!(result.url().query().is_none()); + } + + #[test] + fn test_single_request_no_grouping() { + let req = create_request_with_params(&[("id", "1"), ("name", "test")]); + let requests = vec![&req]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, None) + .transform(base_request) + .to_result() + .unwrap(); + + let params = get_query_params(&result); + assert_eq!(params.len(), 2); + assert_eq!(params.get("id").unwrap(), "1"); + assert_eq!(params.get("name").unwrap(), "test"); + } + + #[test] + fn test_multiple_requests_with_grouping() { + let req1 = create_request_with_params(&[("user_id", "1"), ("extra", "data1")]); + let req2 = create_request_with_params(&[("user_id", "2"), ("extra", "data2")]); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, Some("user_id")) + .transform(base_request) + .to_result() + .unwrap(); + + let params = get_query_params(&result); + assert!(params.contains_key("user_id")); + assert!(!params.contains_key("extra")); + + // URL should contain both user_ids + let url = result.url().to_string(); + assert!(url.contains("user_id=1")); + assert!(url.contains("user_id=2")); + } + + #[test] + fn test_multiple_requests_no_grouping() { + let req1 = create_request_with_params(&[("param1", "value1"), ("shared", "a")]); + let req2 = create_request_with_params(&[("param2", "value2"), ("shared", "b")]); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, None) + .transform(base_request) + .to_result() + .unwrap(); + + let params = get_query_params(&result); + assert_eq!(params.get("param1").unwrap(), "value1"); + assert_eq!(params.get("param2").unwrap(), "value2"); + assert_eq!(params.get("shared").unwrap(), "b"); + } + + #[test] + fn test_requests_with_empty_params() { + let req1 = create_request_with_params(&[("id", "1")]); + let req2 = create_request_with_params(&[]); + let req3 = create_request_with_params(&[("id", "3")]); + let requests = vec![&req1, &req2, &req3]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, Some("id")) + .transform(base_request) + .to_result() + .unwrap(); + + let url = result.url().to_string(); + assert!(url.contains("id=1")); + assert!(url.contains("id=3")); + } + + #[test] + fn test_special_characters() { + let req1 = create_request_with_params(&[("query", "hello world"), ("tag", "a+b")]); + let req2 = create_request_with_params(&[("query", "foo&bar"), ("tag", "c%20d")]); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, None) + .transform(base_request) + .to_result() + .unwrap(); + + let params = get_query_params(&result); + // Verify URL encoding is preserved + assert!(params.values().any(|v| v.contains(" ") || v.contains("&"))); + } + + #[test] + fn test_group_by_with_missing_key() { + let req1 = create_request_with_params(&[("id", "1"), ("data", "test")]); + let req2 = create_request_with_params(&[("other", "2"), ("data", "test2")]); + let requests = vec![&req1, &req2]; + let base_request = create_base_request(); + + let result = QueryBatching::new(&requests, Some("missing_key")) + .transform(base_request) + .to_result() + .unwrap(); + + // Should have no query parameters since grouped key doesn't exist + assert!(result.url().query().is_none()); + } +}