Skip to content

Commit

Permalink
Re-implement per-request redirect configuration
Browse files Browse the repository at this point in the history
ref: seanmonstar#1204

co-authored-by: Greg Kuruc <[email protected]>
  • Loading branch information
mikayla-maki and Greg Kuruc committed Oct 1, 2024
1 parent cf69fd4 commit fe78f10
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let (method, url, mut headers, body, timeout, version) = req.pieces();
let (method, url, mut headers, body, timeout, redirect_policy, version) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
}
Expand Down Expand Up @@ -1894,6 +1894,7 @@ impl Client {
urls: Vec::new(),

retry_count: 0,
redirect_policy,

client: self.inner.clone(),

Expand Down Expand Up @@ -2159,6 +2160,7 @@ pin_project! {
urls: Vec<Url>,

retry_count: usize,
redirect_policy: Option<redirect::Policy>,

client: Arc<ClientRef>,

Expand Down Expand Up @@ -2415,9 +2417,11 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
// This request's redirect policy overrides the client's redirect policy
let action = self
.client
.redirect_policy
.as_ref()
.unwrap_or(&self.client.redirect_policy)
.check(res.status(), &loc, &self.urls);

match action {
Expand Down
27 changes: 26 additions & 1 deletion src/async_impl/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::response::Response;
#[cfg(feature = "multipart")]
use crate::header::CONTENT_LENGTH;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{Method, Url};
use crate::{redirect, Method, Url};
use http::{request::Parts, Request as HttpRequest, Version};

/// A request which can be executed with `Client::execute()`.
Expand All @@ -25,6 +25,7 @@ pub struct Request {
headers: HeaderMap,
body: Option<Body>,
timeout: Option<Duration>,
redirect_policy: Option<redirect::Policy>,
version: Version,
}

Expand All @@ -47,6 +48,7 @@ impl Request {
headers: HeaderMap::new(),
body: None,
timeout: None,
redirect_policy: None,
version: Version::default(),
}
}
Expand Down Expand Up @@ -111,6 +113,18 @@ impl Request {
&mut self.timeout
}

/// Get a this request's redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.redirect_policy.as_ref()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
&mut self.redirect_policy
}

/// Get the http version.
#[inline]
pub fn version(&self) -> Version {
Expand Down Expand Up @@ -147,6 +161,7 @@ impl Request {
HeaderMap,
Option<Body>,
Option<Duration>,
Option<redirect::Policy>,
Version,
) {
(
Expand All @@ -155,6 +170,7 @@ impl Request {
self.headers,
self.body,
self.timeout,
self.redirect_policy,
self.version,
)
}
Expand Down Expand Up @@ -290,6 +306,14 @@ impl RequestBuilder {
self
}

/// Overrides the client's redirect policy for this request
pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy)
}
self
}

/// Sends a multipart/form-data body.
///
/// ```
Expand Down Expand Up @@ -617,6 +641,7 @@ where
headers,
body: Some(body.into()),
timeout: None,
redirect_policy: None,
version,
})
}
Expand Down
62 changes: 62 additions & 0 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,65 @@ async fn test_redirect_https_only_enforced_gh1312() {
let err = res.unwrap_err();
assert!(err.is_redirect());
}

// Code taken from: https://github.com/seanmonstar/reqwest/pull/1204
#[tokio::test]
async fn test_request_redirect() {
let code = 301u16;

let redirect = server::http(move |req| async move {
if req.method() == "POST" {
assert_eq!(req.uri(), &*format!("/{}", code));
http::Response::builder()
.status(code)
.header("location", "/dst")
.header("server", "test-redirect")
.body(Default::default())
.unwrap()
} else {
assert_eq!(req.method(), "GET");

http::Response::builder()
.header("server", "test-dst")
.body(Default::default())
.unwrap()
}
});

let url = format!("http://{}/{}", redirect.addr(), code);
let dst = format!("http://{}/{}", redirect.addr(), "dst");

let default_redirect_client = reqwest::Client::new();
let res = default_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::none())
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-redirect"
);

let no_redirect_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();

let res = no_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::limited(2))
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-dst"
);
}

0 comments on commit fe78f10

Please sign in to comment.