From 1d10b561f491d8b671d100a44749022e7ec5fce6 Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 4 Sep 2024 08:48:43 +1000 Subject: [PATCH] Revert "Revert "Add a request pipeline and configuration for retries (#469)"" This reverts commit 37e4eeff568b46caaf37b7591ce5e0f71fc7c4cb. --- Cargo.toml | 2 +- examples/client_configuration.rs | 24 ++- .../client_credentials_secret.rs | 2 +- graph-core/Cargo.toml | 1 + graph-error/src/graph_failure.rs | 14 ++ graph-http/Cargo.toml | 4 +- graph-http/src/client.rs | 101 +++++++++++- graph-http/src/lib.rs | 2 + graph-http/src/request_handler.rs | 18 ++- graph-http/src/tower_services.rs | 146 ++++++++++++++++++ graph-oauth/Cargo.toml | 1 + 11 files changed, 305 insertions(+), 10 deletions(-) create mode 100644 graph-http/src/tower_services.rs diff --git a/Cargo.toml b/Cargo.toml index bcd39cbf..cd80cd1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,7 @@ rustls-tls = ["reqwest/rustls-tls", "graph-http/rustls-tls", "graph-oauth/rustls brotli = ["reqwest/brotli", "graph-http/brotli", "graph-oauth/brotli", "graph-core/brotli"] deflate = ["reqwest/deflate", "graph-http/deflate", "graph-oauth/deflate", "graph-core/deflate"] trust-dns = ["reqwest/trust-dns", "graph-http/trust-dns", "graph-oauth/trust-dns", "graph-core/trust-dns"] -socks = ["graph-http/socks"] +socks = ["reqwest/socks", "graph-http/socks", "graph-oauth/socks", "graph-core/socks"] openssl = ["graph-oauth/openssl"] interactive-auth = ["graph-oauth/interactive-auth"] test-util = ["graph-http/test-util"] diff --git a/examples/client_configuration.rs b/examples/client_configuration.rs index bc1675b6..7c8b1202 100644 --- a/examples/client_configuration.rs +++ b/examples/client_configuration.rs @@ -1,4 +1,5 @@ #![allow(dead_code, unused, unused_imports, clippy::module_inception)] +use graph_oauth::ConfidentialClientApplication; use graph_rs_sdk::{header::HeaderMap, header::HeaderValue, GraphClient, GraphClientConfiguration}; use http::header::ACCEPT; use http::HeaderName; @@ -10,7 +11,28 @@ fn main() { let client_config = GraphClientConfiguration::new() .access_token(ACCESS_TOKEN) .timeout(Duration::from_secs(30)) - .default_headers(HeaderMap::default()); + .default_headers(HeaderMap::default()) + .retry(Some(10)) // retry 10 times if the request is not successful + .concurrency_limit(Some(10)) // limit the number of concurrent requests on this client to 10 + .wait_for_retry_after_headers(true); // wait the amount of seconds specified by the Retry-After header of the response when we reach the throttling limits (429 Too Many Requests) + + let _ = GraphClient::from(client_config); +} + +// Using Identity Platform Clients +fn configure_graph_client(client_id: &str, client_secret: &str, tenant: &str) { + let mut confidential_client_application = ConfidentialClientApplication::builder(client_id) + .with_client_secret(client_secret) + .with_tenant(tenant) + .build(); + + let client_config = GraphClientConfiguration::new() + .client_application(confidential_client_application) + .timeout(Duration::from_secs(30)) + .default_headers(HeaderMap::default()) + .retry(Some(10)) // retry 10 times if the request is not successful + .concurrency_limit(Some(10)) // limit the number of concurrent requests on this client to 10 + .wait_for_retry_after_headers(true); // wait the amount of seconds specified by the Retry-After header of the response when we reach the throttling limits (429 Too Many Requests) let _ = GraphClient::from(client_config); } diff --git a/examples/identity_platform_auth/client_credentials/client_credentials_secret.rs b/examples/identity_platform_auth/client_credentials/client_credentials_secret.rs index 56ca474a..e676d78c 100644 --- a/examples/identity_platform_auth/client_credentials/client_credentials_secret.rs +++ b/examples/identity_platform_auth/client_credentials/client_credentials_secret.rs @@ -5,7 +5,7 @@ use graph_rs_sdk::{identity::ConfidentialClientApplication, GraphClient}; -pub async fn build_client(client_id: &str, client_secret: &str, tenant: &str) -> GraphClient { +pub fn build_client(client_id: &str, client_secret: &str, tenant: &str) -> GraphClient { let mut confidential_client_application = ConfidentialClientApplication::builder(client_id) .with_client_secret(client_secret) .with_tenant(tenant) diff --git a/graph-core/Cargo.toml b/graph-core/Cargo.toml index 9bb9662e..41596856 100644 --- a/graph-core/Cargo.toml +++ b/graph-core/Cargo.toml @@ -36,3 +36,4 @@ rustls-tls = ["reqwest/rustls-tls"] brotli = ["reqwest/brotli"] deflate = ["reqwest/deflate"] trust-dns = ["reqwest/trust-dns"] +socks = ["reqwest/socks"] diff --git a/graph-error/src/graph_failure.rs b/graph-error/src/graph_failure.rs index 314762b9..954b3e46 100644 --- a/graph-error/src/graph_failure.rs +++ b/graph-error/src/graph_failure.rs @@ -3,8 +3,10 @@ use crate::internal::GraphRsError; use crate::{AuthExecutionError, AuthorizationFailure, ErrorMessage}; use reqwest::header::HeaderMap; use std::cell::BorrowMutError; +use std::error::Error; use std::io; use std::io::ErrorKind; +use std::num::ParseIntError; use std::str::Utf8Error; use std::sync::mpsc; @@ -74,6 +76,12 @@ pub enum GraphFailure { #[error("{0:#?}")] ErrorMessage(#[from] ErrorMessage), + #[error("Temporary Graph API Error")] + TemporaryError, + + #[error("Parse Int error:\n{0:#?}")] + ParseIntError(#[from] ParseIntError), + #[error("message: {0:#?}, response: {1:#?}", message, response)] SilentTokenAuth { message: String, @@ -160,3 +168,9 @@ impl From for GraphFailure { } } } + +impl From> for GraphFailure { + fn from(value: Box) -> Self { + value.into() + } +} diff --git a/graph-http/Cargo.toml b/graph-http/Cargo.toml index 1dce4f38..14528f91 100644 --- a/graph-http/Cargo.toml +++ b/graph-http/Cargo.toml @@ -23,6 +23,8 @@ serde_urlencoded = "0.7.1" thiserror = "1" tokio = { version = "1.27.0", features = ["full", "tracing"] } url = { version = "2", features = ["serde"] } +tower = { version = "0.4.13", features = ["limit", "retry", "timeout", "util"] } +futures-util = "0.3.30" graph-error = { path = "../graph-error" } graph-core = { path = "../graph-core", default-features = false } @@ -34,5 +36,5 @@ rustls-tls = ["reqwest/rustls-tls", "graph-core/rustls-tls"] brotli = ["reqwest/brotli", "graph-core/brotli"] deflate = ["reqwest/deflate", "graph-core/deflate"] trust-dns = ["reqwest/trust-dns", "graph-core/trust-dns"] -socks = ["reqwest/socks"] +socks = ["reqwest/socks", "graph-core/socks"] test-util = [] diff --git a/graph-http/src/client.rs b/graph-http/src/client.rs index 1c3f4150..0f4f298a 100644 --- a/graph-http/src/client.rs +++ b/graph-http/src/client.rs @@ -4,16 +4,28 @@ use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, USER_AGENT}; use reqwest::redirect::Policy; use reqwest::tls::Version; use reqwest::Proxy; +use reqwest::{Request, Response}; use std::env::VarError; use std::ffi::OsStr; use std::fmt::{Debug, Formatter}; use std::time::Duration; +use tower::limit::ConcurrencyLimitLayer; +use tower::retry::RetryLayer; +use tower::util::BoxCloneService; +use tower::ServiceExt; fn user_agent_header_from_env() -> Option { let header = std::option_env!("GRAPH_CLIENT_USER_AGENT")?; HeaderValue::from_str(header).ok() } +#[derive(Default, Clone)] +struct ServiceLayersConfiguration { + concurrency_limit: Option, + retry: Option, + wait_for_retry_after_headers: Option<()>, +} + #[derive(Clone)] struct ClientConfiguration { client_application: Option>, @@ -26,6 +38,7 @@ struct ClientConfiguration { /// TLS 1.2 required to support all features in Microsoft Graph /// See [Reliability and Support](https://learn.microsoft.com/en-us/graph/best-practices-concept#reliability-and-support) min_tls_version: Version, + service_layers_configuration: ServiceLayersConfiguration, proxy: Option, } @@ -47,6 +60,7 @@ impl ClientConfiguration { connection_verbose: false, https_only: true, min_tls_version: Version::TLS_1_2, + service_layers_configuration: ServiceLayersConfiguration::default(), proxy: None, } } @@ -164,6 +178,55 @@ impl GraphClientConfiguration { self } + /// Enable a request retry for a failed request. The retry parameter can be used to + /// change how many times the request should be retried. + /// + /// Some requests may fail on GraphAPI side and should be retried. + /// Only server errors (HTTP code between 500 and 599) will be retried. + /// + /// Default is no retry. + pub fn retry(mut self, retry: Option) -> GraphClientConfiguration { + self.config.service_layers_configuration.retry = retry; + self + } + + /// Enable a request retry if we reach the throttling limits and GraphAPI returns a + /// 429 Too Many Requests with a Retry-After header + /// + /// Retry attempts are executed when the response has a status code of 429, 500, 503, 504 + /// and the response has a Retry-After header. The Retry-After header provides a back-off + /// time to wait for before retrying the request again. + /// + /// Be careful with this parameter as some API endpoints have quite + /// low limits (reports for example) and the request may hang for hundreds of seconds. + /// For maximum throughput you may want to not respect the Retry-After header as hitting + /// another server thanks to load-balancing may lead to a successful response. + /// + /// Default is no retry. + pub fn wait_for_retry_after_headers(mut self, retry: bool) -> GraphClientConfiguration { + self.config + .service_layers_configuration + .wait_for_retry_after_headers = match retry { + true => Some(()), + false => None, + }; + self + } + + /// Enable a concurrency limit on the client. + /// + /// Every request through this client will be subject to a concurrency limit. + /// Can be useful to stay under the API limits set by GraphAPI. + /// + /// Default is no concurrency limit. + pub fn concurrency_limit( + mut self, + concurrency_limit: Option, + ) -> GraphClientConfiguration { + self.config.service_layers_configuration.concurrency_limit = concurrency_limit; + self + } + pub fn build(self) -> Client { let config = self.clone(); let headers = self.config.headers.clone(); @@ -187,19 +250,45 @@ impl GraphClientConfiguration { builder = builder.proxy(proxy); } + let client = builder.build().unwrap(); + + let service = tower::ServiceBuilder::new() + .option_layer( + self.config + .service_layers_configuration + .retry + .map(|num| RetryLayer::new(crate::tower_services::Attempts(num))), + ) + .option_layer( + self.config + .service_layers_configuration + .wait_for_retry_after_headers + .map(|_| RetryLayer::new(crate::tower_services::WaitFor())), + ) + .option_layer( + self.config + .service_layers_configuration + .concurrency_limit + .map(ConcurrencyLimitLayer::new), + ) + .service(client.clone()) + .boxed_clone(); + if let Some(client_application) = self.config.client_application { Client { client_application, - inner: builder.build().unwrap(), + inner: client, headers, builder: config, + service, } } else { Client { client_application: Box::::default(), - inner: builder.build().unwrap(), + inner: client, headers, builder: config, + service, } } } @@ -226,16 +315,18 @@ impl GraphClientConfiguration { builder = builder.proxy(proxy); } + let client = builder.build().unwrap(); + if let Some(client_application) = self.config.client_application { BlockingClient { client_application, - inner: builder.build().unwrap(), + inner: client, headers, } } else { BlockingClient { client_application: Box::::default(), - inner: builder.build().unwrap(), + inner: client, headers, } } @@ -254,6 +345,8 @@ pub struct Client { pub(crate) inner: reqwest::Client, pub(crate) headers: HeaderMap, pub(crate) builder: GraphClientConfiguration, + pub(crate) service: + BoxCloneService>, } impl Client { diff --git a/graph-http/src/lib.rs b/graph-http/src/lib.rs index d89846f5..c9e99e58 100644 --- a/graph-http/src/lib.rs +++ b/graph-http/src/lib.rs @@ -7,6 +7,7 @@ mod core; mod request_components; mod request_handler; mod resource_identifier; +mod tower_services; mod upload_session; pub mod url; @@ -27,6 +28,7 @@ pub(crate) mod internal { pub use crate::request_handler::*; #[allow(unused_imports)] pub use crate::resource_identifier::*; + pub use crate::tower_services::*; pub use crate::traits::*; pub use crate::upload_session::*; pub use graph_core::http::*; diff --git a/graph-http/src/request_handler.rs b/graph-http/src/request_handler.rs index 84355a25..c6e708e5 100644 --- a/graph-http/src/request_handler.rs +++ b/graph-http/src/request_handler.rs @@ -7,19 +7,23 @@ use async_stream::try_stream; use futures::Stream; use graph_error::{AuthExecutionResult, ErrorMessage, GraphFailure, GraphResult}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; +use reqwest::{Request, Response}; use serde::de::DeserializeOwned; use std::collections::VecDeque; use std::fmt::Debug; use std::time::Duration; +use tower::util::BoxCloneService; +use tower::{Service, ServiceExt}; use url::Url; -#[derive(Default)] pub struct RequestHandler { pub(crate) inner: Client, pub(crate) request_components: RequestComponents, pub(crate) error: Option, pub(crate) body: Option, pub(crate) client_builder: GraphClientConfiguration, + pub(crate) service: + BoxCloneService>, } impl RequestHandler { @@ -29,6 +33,7 @@ impl RequestHandler { err: Option, body: Option, ) -> RequestHandler { + let service = inner.service.clone(); let client_builder = inner.builder.clone(); let mut original_headers = inner.headers.clone(); original_headers.extend(request_components.headers.clone()); @@ -51,6 +56,7 @@ impl RequestHandler { error, body, client_builder, + service, } } @@ -242,8 +248,16 @@ impl RequestHandler { #[inline] pub async fn send(self) -> GraphResult { + let mut service = self.service.clone(); let request_builder = self.build().await?; - request_builder.send().await.map_err(GraphFailure::from) + let request = request_builder.build()?; + service + .ready() + .await + .map_err(GraphFailure::from)? + .call(request) + .await + .map_err(GraphFailure::from) } } diff --git a/graph-http/src/tower_services.rs b/graph-http/src/tower_services.rs new file mode 100644 index 00000000..2f4e8a50 --- /dev/null +++ b/graph-http/src/tower_services.rs @@ -0,0 +1,146 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use std::{sync::Mutex, task::Waker, thread}; + +use futures_util::future; +use http::StatusCode; +use reqwest::{Request, Response}; + +#[derive(Clone)] +pub(crate) struct Attempts(pub usize); + +impl tower::retry::Policy> + for Attempts +{ + type Future = future::Ready; + + fn retry( + &self, + _req: &Request, + result: Result<&Response, &Box<(dyn std::error::Error + Send + Sync + 'static)>>, + ) -> Option { + match result { + Ok(response) => { + if response.status().is_server_error() && self.0 > 0 { + return Some(future::ready(Attempts(self.0 - 1))); + } + None + } + Err(_) => { + if self.0 > 0 { + Some(future::ready(Attempts(self.0 - 1))) + } else { + None + } + } + } + } + + fn clone_request(&self, req: &Request) -> Option { + req.try_clone() + } +} + +#[derive(Clone)] +pub(crate) struct WaitFor(); + +impl tower::retry::Policy> + for WaitFor +{ + type Future = future::Either, WaitBeforeRetry>; + + fn retry( + &self, + _req: &Request, + result: Result<&Response, &Box<(dyn std::error::Error + Send + Sync + 'static)>>, + ) -> Option { + match result { + Ok(response) => match response.status() { + StatusCode::TOO_MANY_REQUESTS + | StatusCode::INTERNAL_SERVER_ERROR + | StatusCode::SERVICE_UNAVAILABLE + | StatusCode::GATEWAY_TIMEOUT => match response.headers().get("Retry-After") { + Some(retry_after) => match retry_after.to_str() { + Ok(ra) => match ra.parse::() { + Ok(retry_after) => { + let sleep = WaitBeforeRetry::new( + Some(WaitFor()), + Duration::from_secs(retry_after), + ); + Some(future::Either::Right(sleep)) + } + Err(_) => None, + }, + Err(_) => None, + }, + None => None, + }, + _ => None, + }, + Err(_) => None, + } + } + + fn clone_request(&self, req: &Request) -> Option { + req.try_clone() + } +} + +pub struct WaitBeforeRetry { + inner: Option, + shared_state: Arc>, +} + +struct SharedState { + completed: bool, + waker: Option, +} + +impl WaitBeforeRetry { + pub fn new(inner: Option, duration: Duration) -> Self { + let shared_state = Arc::new(Mutex::new(SharedState { + completed: false, + waker: None, + })); + + // Spawn the new thread + let thread_shared_state = shared_state.clone(); + thread::spawn(move || { + thread::sleep(duration); + let mut shared_state = thread_shared_state.lock().unwrap(); + // Signal that the timer has completed and wake up the last + // task on which the future was polled, if one exists. + shared_state.completed = true; + if let Some(waker) = shared_state.waker.take() { + waker.wake() + } + }); + + WaitBeforeRetry { + inner, + shared_state, + } + } +} + +impl Unpin for WaitBeforeRetry {} + +impl Future for WaitBeforeRetry { + type Output = T; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + { + let mut shared_state = self.shared_state.lock().unwrap(); + if !shared_state.completed { + shared_state.waker = Some(cx.waker().clone()); + return Poll::Pending; + } + } + + Poll::Ready(self.inner.take().expect("Ready polled after completion")) + } +} diff --git a/graph-oauth/Cargo.toml b/graph-oauth/Cargo.toml index 764cc77a..c100338c 100644 --- a/graph-oauth/Cargo.toml +++ b/graph-oauth/Cargo.toml @@ -50,6 +50,7 @@ rustls-tls = ["reqwest/rustls-tls", "graph-core/rustls-tls"] brotli = ["reqwest/brotli", "graph-core/brotli"] deflate = ["reqwest/deflate", "graph-core/deflate"] trust-dns = ["reqwest/trust-dns", "graph-core/trust-dns"] +socks = ["reqwest/socks", "graph-core/socks"] openssl = ["dep:openssl"] interactive-auth = ["dep:wry", "dep:tao"]