diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 7b21db849..c14ca4d4a 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1088,7 +1088,7 @@ impl Client { } pub(super) fn execute_request(&self, req: Request) -> Pending { - let (method, url, mut headers, body, timeout) = req.pieces(); + let (method, url, mut headers, body, timeout, redirect_policy) = req.pieces(); if url.scheme() != "http" && url.scheme() != "https" { return Pending::new_err(error::url_bad_scheme(url)); } @@ -1165,6 +1165,7 @@ impl Client { in_flight, timeout, + redirect_policy, }), } } @@ -1361,6 +1362,7 @@ pin_project! { in_flight: ResponseFuture, #[pin] timeout: Option>>, + redirect_policy: Option, } } @@ -1503,10 +1505,15 @@ impl Future for PendingRequest { } let url = self.url.clone(); self.as_mut().urls().push(url); - let action = self - .client - .redirect_policy - .check(res.status(), &loc, &self.urls); + + // Request specific redirect policy takes precedence + // over client redirect policy + let policy = match &self.redirect_policy { + Some(p) => p, + None => &self.client.redirect_policy + }; + + let action = policy.check(res.status(), &loc, &self.urls); match action { redirect::ActionKind::Follow => { diff --git a/src/async_impl/request.rs b/src/async_impl/request.rs index 3e089e960..e027caf4c 100644 --- a/src/async_impl/request.rs +++ b/src/async_impl/request.rs @@ -17,7 +17,7 @@ use super::response::Response; #[cfg(feature = "multipart")] use crate::header::CONTENT_LENGTH; use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; -use crate::{Method, Url}; +use crate::{Method, Url, redirect}; use http::{request::Parts, Request as HttpRequest}; /// A request which can be executed with `Client::execute()`. @@ -27,6 +27,7 @@ pub struct Request { headers: HeaderMap, body: Option, timeout: Option, + redirect_policy: Option, } /// A builder to construct the properties of a `Request`. @@ -48,6 +49,7 @@ impl Request { headers: HeaderMap::new(), body: None, timeout: None, + redirect_policy: None, } } @@ -111,6 +113,18 @@ impl Request { &mut self.timeout } + /// Get the redirect policy. + #[inline] + pub fn redirect_policy(&self) -> Option<&redirect::Policy> { + self.redirect_policy.as_ref() + } + + /// Get a mutable reference to the redirect policy. + #[inline] + pub fn redirect_policy_mut(&mut self) -> &mut Option { + &mut self.redirect_policy + } + /// Attempt to clone the request. /// /// `None` is returned if the request can not be cloned, i.e. if the body is a stream. @@ -122,12 +136,13 @@ impl Request { let mut req = Request::new(self.method().clone(), self.url().clone()); *req.timeout_mut() = self.timeout().cloned(); *req.headers_mut() = self.headers().clone(); + *req.redirect_policy_mut() = self.redirect_policy().cloned(); req.body = body; Some(req) } - pub(super) fn pieces(self) -> (Method, Url, HeaderMap, Option, Option) { - (self.method, self.url, self.headers, self.body, self.timeout) + pub(super) fn pieces(self) -> (Method, Url, HeaderMap, Option, Option, Option) { + (self.method, self.url, self.headers, self.body, self.timeout, self.redirect_policy) } } @@ -244,6 +259,17 @@ impl RequestBuilder { self } + /// Enables a request specific redirect policy. + /// + /// It affects only this request and overrides + /// the request policy configured using `ClientBuilder::request()`. + pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder { + if let Ok(ref mut req) = self.request { + *req.redirect_policy_mut() = Some(policy); + } + self + } + /// Sends a multipart/form-data body. /// /// ``` @@ -525,6 +551,7 @@ where headers, body: Some(body.into()), timeout: None, + redirect_policy: None, }) } } diff --git a/src/blocking/request.rs b/src/blocking/request.rs index 7b36d3ffa..7624a1999 100644 --- a/src/blocking/request.rs +++ b/src/blocking/request.rs @@ -102,6 +102,18 @@ impl Request { self.inner.timeout_mut() } + /// Get the redirect policy. + #[inline] + pub fn redirect_policy(&self) -> Option<&redirect::Policy> { + self.inner.redirect_policy() + } + + /// Get a mutable reference to the redirect policy. + #[inline] + pub fn redirect_policy_mut(&mut self) -> &mut Option { + self.inner.redirect_policy_mut() + } + /// Attempts to clone the `Request`. /// /// None is returned if a body is which can not be cloned. This can be because the body is a @@ -342,6 +354,17 @@ impl RequestBuilder { self } + /// Enables a request specific redirect policy. + /// + /// It affects only this request and overrides + /// the request policy configured using `ClientBuilder::request()`. + pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder { + if let Ok(ref mut req) = self.request { + *req.redirect_policy_mut() = Some(policy); + } + self + } + /// Modify the query string of the URL. /// /// Modifies the URL of this request, adding the parameters provided. diff --git a/src/redirect.rs b/src/redirect.rs index a0722a6bc..613d83210 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -21,6 +21,7 @@ use crate::Url; /// the allowed maximum redirect hops in a chain. /// - `none` can be used to disable all redirect behavior. /// - `custom` can be used to create a customized policy. +#[derive(Clone)] pub struct Policy { inner: PolicyKind, } @@ -209,6 +210,16 @@ enum PolicyKind { None, } +impl Clone for PolicyKind { + fn clone(&self) -> Self { + match self { + c @ PolicyKind::Custom(_) => c.clone(), + l @ PolicyKind::Limit(_) => l.clone(), + PolicyKind::None => PolicyKind::None + } + } +} + impl fmt::Debug for Policy { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_tuple("Policy").field(&self.inner).finish() diff --git a/tests/redirect.rs b/tests/redirect.rs index 16f7712f5..5f5e9eab5 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -316,3 +316,64 @@ async fn test_redirect_302_with_set_cookies() { assert_eq!(res.url().as_str(), dst); assert_eq!(res.status(), reqwest::StatusCode::OK); } + +#[tokio::test] +async fn test_request_redirect() { + let code = 301u16; + + let redirect = server::http(move |req| async move { + if req.method() == "POST" { + assert_eq!(req.uri(), &*format!("/{}", code)); + http::Response::builder() + .status(code) + .header("location", "/dst") + .header("server", "test-redirect") + .body(Default::default()) + .unwrap() + } else { + assert_eq!(req.method(), "GET"); + + http::Response::builder() + .header("server", "test-dst") + .body(Default::default()) + .unwrap() + } + }); + + let url = format!("http://{}/{}", redirect.addr(), code); + let dst = format!("http://{}/{}", redirect.addr(), "dst"); + + let default_redirect_client = reqwest::Client::new(); + let res = default_redirect_client + .request(reqwest::Method::POST, &url) + .redirect(reqwest::redirect::Policy::none()) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY); + assert_eq!( + res.headers().get(reqwest::header::SERVER).unwrap(), + &"test-redirect" + ); + + let no_redirect_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + let res = no_redirect_client + .request(reqwest::Method::POST, &url) + .redirect(reqwest::redirect::Policy::limited(2)) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), dst); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!( + res.headers().get(reqwest::header::SERVER).unwrap(), + &"test-dst" + ); +}