From fe78f10dbfc78d13186b19f0858f3f05328d1d03 Mon Sep 17 00:00:00 2001 From: Mikayla Date: Mon, 30 Sep 2024 20:52:43 -0700 Subject: [PATCH] Re-implement per-request redirect configuration ref: https://github.com/seanmonstar/reqwest/pull/1204 co-authored-by: Greg Kuruc --- src/async_impl/client.rs | 8 +++-- src/async_impl/request.rs | 27 ++++++++++++++++- tests/redirect.rs | 62 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index b247624ed..c72fdb3af 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1807,7 +1807,7 @@ impl Client { } pub(super) fn execute_request(&self, req: Request) -> Pending { - let (method, url, mut headers, body, timeout, version) = req.pieces(); + let (method, url, mut headers, body, timeout, redirect_policy, version) = req.pieces(); if url.scheme() != "http" && url.scheme() != "https" { return Pending::new_err(error::url_bad_scheme(url)); } @@ -1894,6 +1894,7 @@ impl Client { urls: Vec::new(), retry_count: 0, + redirect_policy, client: self.inner.clone(), @@ -2159,6 +2160,7 @@ pin_project! { urls: Vec, retry_count: usize, + redirect_policy: Option, client: Arc, @@ -2415,9 +2417,11 @@ impl Future for PendingRequest { } let url = self.url.clone(); self.as_mut().urls().push(url); + // This request's redirect policy overrides the client's redirect policy let action = self - .client .redirect_policy + .as_ref() + .unwrap_or(&self.client.redirect_policy) .check(res.status(), &loc, &self.urls); match action { diff --git a/src/async_impl/request.rs b/src/async_impl/request.rs index 665710430..4186f57b8 100644 --- a/src/async_impl/request.rs +++ b/src/async_impl/request.rs @@ -15,7 +15,7 @@ use super::response::Response; #[cfg(feature = "multipart")] use crate::header::CONTENT_LENGTH; use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; -use crate::{Method, Url}; +use crate::{redirect, Method, Url}; use http::{request::Parts, Request as HttpRequest, Version}; /// A request which can be executed with `Client::execute()`. @@ -25,6 +25,7 @@ pub struct Request { headers: HeaderMap, body: Option, timeout: Option, + redirect_policy: Option, version: Version, } @@ -47,6 +48,7 @@ impl Request { headers: HeaderMap::new(), body: None, timeout: None, + redirect_policy: None, version: Version::default(), } } @@ -111,6 +113,18 @@ impl Request { &mut self.timeout } + /// Get a this request's redirect policy. + #[inline] + pub fn redirect_policy(&self) -> Option<&redirect::Policy> { + self.redirect_policy.as_ref() + } + + /// Get a mutable reference to the redirect policy. + #[inline] + pub fn redirect_policy_mut(&mut self) -> &mut Option { + &mut self.redirect_policy + } + /// Get the http version. #[inline] pub fn version(&self) -> Version { @@ -147,6 +161,7 @@ impl Request { HeaderMap, Option, Option, + Option, Version, ) { ( @@ -155,6 +170,7 @@ impl Request { self.headers, self.body, self.timeout, + self.redirect_policy, self.version, ) } @@ -290,6 +306,14 @@ impl RequestBuilder { self } + /// Overrides the client's redirect policy for this request + pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder { + if let Ok(ref mut req) = self.request { + *req.redirect_policy_mut() = Some(policy) + } + self + } + /// Sends a multipart/form-data body. /// /// ``` @@ -617,6 +641,7 @@ where headers, body: Some(body.into()), timeout: None, + redirect_policy: None, version, }) } diff --git a/tests/redirect.rs b/tests/redirect.rs index 9df6265a4..8503b5cc0 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -361,3 +361,65 @@ async fn test_redirect_https_only_enforced_gh1312() { let err = res.unwrap_err(); assert!(err.is_redirect()); } + +// Code taken from: https://github.com/seanmonstar/reqwest/pull/1204 +#[tokio::test] +async fn test_request_redirect() { + let code = 301u16; + + let redirect = server::http(move |req| async move { + if req.method() == "POST" { + assert_eq!(req.uri(), &*format!("/{}", code)); + http::Response::builder() + .status(code) + .header("location", "/dst") + .header("server", "test-redirect") + .body(Default::default()) + .unwrap() + } else { + assert_eq!(req.method(), "GET"); + + http::Response::builder() + .header("server", "test-dst") + .body(Default::default()) + .unwrap() + } + }); + + let url = format!("http://{}/{}", redirect.addr(), code); + let dst = format!("http://{}/{}", redirect.addr(), "dst"); + + let default_redirect_client = reqwest::Client::new(); + let res = default_redirect_client + .request(reqwest::Method::POST, &url) + .redirect(reqwest::redirect::Policy::none()) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY); + assert_eq!( + res.headers().get(reqwest::header::SERVER).unwrap(), + &"test-redirect" + ); + + let no_redirect_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + let res = no_redirect_client + .request(reqwest::Method::POST, &url) + .redirect(reqwest::redirect::Policy::limited(2)) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), dst); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!( + res.headers().get(reqwest::header::SERVER).unwrap(), + &"test-dst" + ); +}