Skip to content

Commit

Permalink
Add ability to set request specific redirect policy
Browse files Browse the repository at this point in the history
A request's redirect policy will override the redirect policy for that particular request. Similar to a timeout for a specific request, it overrides the redirect policy configured by the client
  • Loading branch information
Greg Kuruc committed Mar 5, 2021
1 parent a856638 commit eea9b4b
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 9 deletions.
17 changes: 12 additions & 5 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let (method, url, mut headers, body, timeout) = req.pieces();
let (method, url, mut headers, body, timeout, redirect_policy) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
}
Expand Down Expand Up @@ -1165,6 +1165,7 @@ impl Client {

in_flight,
timeout,
redirect_policy,
}),
}
}
Expand Down Expand Up @@ -1361,6 +1362,7 @@ pin_project! {
in_flight: ResponseFuture,
#[pin]
timeout: Option<Pin<Box<Sleep>>>,
redirect_policy: Option<redirect::Policy>,
}
}

Expand Down Expand Up @@ -1503,10 +1505,15 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
let action = self
.client
.redirect_policy
.check(res.status(), &loc, &self.urls);

// Request specific redirect policy takes precedence
// over client redirect policy
let policy = match &self.redirect_policy {
Some(p) => p,
None => &self.client.redirect_policy,
};

let action = policy.check(res.status(), &loc, &self.urls);

match action {
redirect::ActionKind::Follow => {
Expand Down
49 changes: 46 additions & 3 deletions src/async_impl/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use super::response::Response;
#[cfg(feature = "multipart")]
use crate::header::CONTENT_LENGTH;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{Method, Url};
use crate::{redirect, Method, Url};
use http::{request::Parts, Request as HttpRequest};

/// A request which can be executed with `Client::execute()`.
Expand All @@ -27,6 +27,7 @@ pub struct Request {
headers: HeaderMap,
body: Option<Body>,
timeout: Option<Duration>,
redirect_policy: Option<redirect::Policy>,
}

/// A builder to construct the properties of a `Request`.
Expand All @@ -48,6 +49,7 @@ impl Request {
headers: HeaderMap::new(),
body: None,
timeout: None,
redirect_policy: None,
}
}

Expand Down Expand Up @@ -111,6 +113,18 @@ impl Request {
&mut self.timeout
}

/// Get the redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.redirect_policy.as_ref()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
&mut self.redirect_policy
}

/// Attempt to clone the request.
///
/// `None` is returned if the request can not be cloned, i.e. if the body is a stream.
Expand All @@ -122,12 +136,29 @@ impl Request {
let mut req = Request::new(self.method().clone(), self.url().clone());
*req.timeout_mut() = self.timeout().cloned();
*req.headers_mut() = self.headers().clone();
*req.redirect_policy_mut() = self.redirect_policy().cloned();
req.body = body;
Some(req)
}

pub(super) fn pieces(self) -> (Method, Url, HeaderMap, Option<Body>, Option<Duration>) {
(self.method, self.url, self.headers, self.body, self.timeout)
pub(super) fn pieces(
self,
) -> (
Method,
Url,
HeaderMap,
Option<Body>,
Option<Duration>,
Option<redirect::Policy>,
) {
(
self.method,
self.url,
self.headers,
self.body,
self.timeout,
self.redirect_policy,
)
}
}

Expand Down Expand Up @@ -244,6 +275,17 @@ impl RequestBuilder {
self
}

/// Enables a request specific redirect policy.
///
/// It affects only this request and overrides
/// the request policy configured using `ClientBuilder::request()`.
pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy);
}
self
}

/// Sends a multipart/form-data body.
///
/// ```
Expand Down Expand Up @@ -525,6 +567,7 @@ where
headers,
body: Some(body.into()),
timeout: None,
redirect_policy: None,
})
}
}
Expand Down
25 changes: 24 additions & 1 deletion src/blocking/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::body::{self, Body};
use super::multipart;
use super::Client;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{async_impl, Method, Url};
use crate::{async_impl, redirect, Method, Url};

/// A request which can be executed with `Client::execute()`.
pub struct Request {
Expand Down Expand Up @@ -102,6 +102,18 @@ impl Request {
self.inner.timeout_mut()
}

/// Get the redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.inner.redirect_policy()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
self.inner.redirect_policy_mut()
}

/// Attempts to clone the `Request`.
///
/// None is returned if a body is which can not be cloned. This can be because the body is a
Expand Down Expand Up @@ -342,6 +354,17 @@ impl RequestBuilder {
self
}

/// Enables a request specific redirect policy.
///
/// It affects only this request and overrides
/// the request policy configured using `ClientBuilder::request()`.
pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy);
}
self
}

/// Modify the query string of the URL.
///
/// Modifies the URL of this request, adding the parameters provided.
Expand Down
11 changes: 11 additions & 0 deletions src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::Url;
/// the allowed maximum redirect hops in a chain.
/// - `none` can be used to disable all redirect behavior.
/// - `custom` can be used to create a customized policy.
#[derive(Clone)]
pub struct Policy {
inner: PolicyKind,
}
Expand Down Expand Up @@ -209,6 +210,16 @@ enum PolicyKind {
None,
}

impl Clone for PolicyKind {
fn clone(&self) -> Self {
match self {
c @ PolicyKind::Custom(_) => c.clone(),
l @ PolicyKind::Limit(_) => l.clone(),
PolicyKind::None => PolicyKind::None,
}
}
}

impl fmt::Debug for Policy {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("Policy").field(&self.inner).finish()
Expand Down
61 changes: 61 additions & 0 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,64 @@ async fn test_redirect_302_with_set_cookies() {
assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}

#[tokio::test]
async fn test_request_redirect() {
let code = 301u16;

let redirect = server::http(move |req| async move {
if req.method() == "POST" {
assert_eq!(req.uri(), &*format!("/{}", code));
http::Response::builder()
.status(code)
.header("location", "/dst")
.header("server", "test-redirect")
.body(Default::default())
.unwrap()
} else {
assert_eq!(req.method(), "GET");

http::Response::builder()
.header("server", "test-dst")
.body(Default::default())
.unwrap()
}
});

let url = format!("http://{}/{}", redirect.addr(), code);
let dst = format!("http://{}/{}", redirect.addr(), "dst");

let default_redirect_client = reqwest::Client::new();
let res = default_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::none())
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-redirect"
);

let no_redirect_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();

let res = no_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::limited(2))
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-dst"
);
}

0 comments on commit eea9b4b

Please sign in to comment.