From 668fff2e04fe66cf52e605215f43a9949dedece9 Mon Sep 17 00:00:00 2001 From: WinLinux1028 Date: Sat, 18 Nov 2023 05:26:48 +0900 Subject: [PATCH 1/3] supporting custom proxy protocol --- Cargo.toml | 2 + examples/custom_proxy_protocol.rs | 88 ++++++++++++++++++++ src/connect.rs | 129 ++++++++++++++++++------------ src/lib.rs | 2 +- src/proxy.rs | 126 ++++++++++++++++++++++++++++- 5 files changed, 290 insertions(+), 57 deletions(-) create mode 100644 examples/custom_proxy_protocol.rs diff --git a/Cargo.toml b/Cargo.toml index 1e517f824..77d06ef1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,8 @@ bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" tower-service = "0.3" +async-trait = "0.1" +dyn-clone = "1" futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } diff --git a/examples/custom_proxy_protocol.rs b/examples/custom_proxy_protocol.rs new file mode 100644 index 000000000..c0494d7e6 --- /dev/null +++ b/examples/custom_proxy_protocol.rs @@ -0,0 +1,88 @@ +use std::{error::Error, io::Write, pin::pin}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; + +use async_trait::async_trait; +use http::Uri; +use reqwest::{AsyncStream, Client, CustomProxyProtocol, Proxy}; + +#[tokio::main] +async fn main() { + let proxy: Box = Box::new(Example()); + let client = Client::builder() + .proxy(Proxy::all(proxy).unwrap()) + .http1_only() + .build() + .unwrap(); + let response = client + .get("http://www.hal.ipc.i.u-tokyo.ac.jp/~nakada/prog2015/alice.txt") + .send() + .await + .unwrap(); + let body = response.bytes().await.unwrap(); + + let mut stdout = std::io::stdout(); + stdout.write_all(&body).unwrap(); + stdout.flush().unwrap(); +} + +#[derive(Clone)] +struct Example(); +#[async_trait] +impl CustomProxyProtocol for Example { + async fn connect( + &self, + dst: Uri, + ) -> Result, Box> { + let host = dst.host().ok_or("host is None")?; + let port = match (dst.scheme_str(), dst.port_u16()) { + (_, Some(p)) => p, + (Some("http"), None) => 80, + (Some("https"), None) => 443, + _ => return Err("scheme is unknown and port is None.".into()), + }; + eprintln!("Connecting to {}:{}", host, port); + Ok(Box::new(WrapStream( + TcpStream::connect(format!("{}:{}", host, port)).await?, + ))) + } +} + +struct WrapStream(RW); +impl AsyncRead for WrapStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + eprintln!("read"); + pin!(&mut self.0).poll_read(cx, buf) + } +} +impl AsyncWrite for WrapStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + eprintln!("write"); + std::io::stderr().write_all(buf).unwrap(); + pin!(&mut self.0).poll_write(cx, buf) + } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + eprintln!("flush"); + pin!(&mut self.0).poll_flush(cx) + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + eprintln!("shutdown"); + pin!(&mut self.0).poll_shutdown(cx) + } +} diff --git a/src/connect.rs b/src/connect.rs index c171dd18d..526ee5b30 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -23,7 +23,7 @@ use self::native_tls_conn::NativeTlsConn; use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; use crate::error::BoxError; -use crate::proxy::{Proxy, ProxyScheme}; +use crate::proxy::{AsyncStreamWrapper, Proxy, ProxyScheme}; pub(crate) type HttpConnector = hyper::client::HttpConnector; @@ -179,7 +179,7 @@ impl Connector { ProxyScheme::Socks5 { remote_dns: true, .. } => socks::DnsResolve::Proxy, - ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => { + _ => { unreachable!("connect_socks is only called for socks proxies"); } }; @@ -319,6 +319,56 @@ impl Connector { let (proxy_dst, _auth) = match proxy_scheme { ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), + ProxyScheme::Custom(ref p) => { + let p = p.clone(); + match &self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(_http, tls) => { + if dst.scheme() == Some(&Scheme::HTTPS) { + let host = dst.host().ok_or("no host in url")?.to_string(); + let conn = p.connect(dst).await?; + let conn = AsyncStreamWrapper::from(conn); + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); + let io = tls_connector.connect(&host, conn).await?; + return Ok(Conn { + inner: self.verbose.wrap(NativeTlsConn { inner: io }), + is_proxy: false, + tls_info: self.tls_info, + }); + } + } + #[cfg(feature = "__rustls")] + Inner::RustlsTls { tls_proxy, .. } => { + if dst.scheme() == Some(&Scheme::HTTPS) { + use std::convert::TryFrom; + use tokio_rustls::TlsConnector as RustlsConnector; + + let tls = tls_proxy.clone(); + let host = dst.host().ok_or("no host in url")?.to_string(); + let conn = p.connect(dst).await?; + let conn = AsyncStreamWrapper::from(conn); + let server_name = rustls::ServerName::try_from(host.as_str()) + .map_err(|_| "Invalid Server Name")?; + let io = RustlsConnector::from(tls) + .connect(server_name, conn) + .await?; + return Ok(Conn { + inner: self.verbose.wrap(RustlsTlsConn { inner: io }), + is_proxy: false, + tls_info: false, + }); + } + } + #[cfg(not(feature = "__tls"))] + Inner::Http(_) => (), + } + + return p.connect(dst).await.map(|tcp| Conn { + inner: self.verbose.wrap(AsyncStreamWrapper::from(tcp)), + is_proxy: false, + tls_info: false, + }); + } #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await, }; @@ -466,6 +516,13 @@ trait TlsInfoFactory { fn tls_info(&self) -> Option; } +#[cfg(feature = "__tls")] +impl TlsInfoFactory for AsyncStreamWrapper { + fn tls_info(&self) -> Option { + None + } +} + #[cfg(feature = "__tls")] impl TlsInfoFactory for tokio::net::TcpStream { fn tls_info(&self) -> Option { @@ -474,7 +531,10 @@ impl TlsInfoFactory for tokio::net::TcpStream { } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { +impl TlsInfoFactory for hyper_tls::MaybeHttpsStream +where + RW: AsyncRead + AsyncWrite + Unpin, +{ fn tls_info(&self) -> Option { match self { hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), @@ -484,20 +544,10 @@ impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::TlsStream> { - fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .peer_certificate() - .ok() - .flatten() - .and_then(|c| c.to_der().ok()); - Some(crate::tls::TlsInfo { peer_certificate }) - } -} - -#[cfg(feature = "default-tls")] -impl TlsInfoFactory for tokio_native_tls::TlsStream { +impl TlsInfoFactory for tokio_native_tls::TlsStream +where + RW: AsyncRead + AsyncWrite + Unpin, +{ fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -510,7 +560,7 @@ impl TlsInfoFactory for tokio_native_tls::TlsStream { } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { +impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { fn tls_info(&self) -> Option { match self { hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), @@ -520,22 +570,7 @@ impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::TlsStream { - fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .1 - .peer_certificates() - .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); - Some(crate::tls::TlsInfo { peer_certificate }) - } -} - -#[cfg(feature = "__rustls")] -impl TlsInfoFactory - for tokio_rustls::client::TlsStream> -{ +impl TlsInfoFactory for tokio_rustls::TlsStream { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -548,7 +583,7 @@ impl TlsInfoFactory } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::client::TlsStream { +impl TlsInfoFactory for tokio_rustls::client::TlsStream { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -561,11 +596,10 @@ impl TlsInfoFactory for tokio_rustls::client::TlsStream { } pub(crate) trait AsyncConn: - AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static + AsyncRead + AsyncWrite + Connection + Send + Unpin + 'static { } - -impl AsyncConn for T {} +impl AsyncConn for T {} #[cfg(feature = "__tls")] trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {} @@ -824,13 +858,10 @@ mod native_tls_conn { } } - impl TlsInfoFactory for NativeTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for NativeTlsConn> { + impl TlsInfoFactory for NativeTlsConn + where + RW: AsyncRead + AsyncWrite + Unpin, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -917,13 +948,7 @@ mod rustls_tls_conn { } } - impl TlsInfoFactory for RustlsTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for RustlsTlsConn> { + impl TlsInfoFactory for RustlsTlsConn { fn tls_info(&self) -> Option { self.inner.tls_info() } diff --git a/src/lib.rs b/src/lib.rs index 188ba4f02..bded6f23f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -326,7 +326,7 @@ if_hyper! { pub use self::async_impl::{ Body, Client, ClientBuilder, Request, RequestBuilder, Response, Upgraded, }; - pub use self::proxy::{Proxy,NoProxy}; + pub use self::proxy::{Proxy,NoProxy,CustomProxyProtocol,AsyncStream}; #[cfg(feature = "__tls")] // Re-exports, to be removed in a future release pub use tls::{Certificate, Identity}; diff --git a/src/proxy.rs b/src/proxy.rs index 6e1bfcc73..9136ec9be 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,11 +1,14 @@ -use std::fmt; +use std::fmt::{self, Debug}; #[cfg(feature = "socks")] use std::net::SocketAddr; use std::sync::Arc; use crate::into_url::{IntoUrl, IntoUrlSealed}; use crate::Url; +use async_trait::async_trait; +use dyn_clone::DynClone; use http::{header::HeaderValue, Uri}; +use hyper::client::connect::{Connected, Connection}; use ipnet::IpNet; use once_cell::sync::Lazy; use percent_encoding::percent_decode; @@ -13,6 +16,7 @@ use std::collections::HashMap; use std::env; use std::error::Error; use std::net::IpAddr; +use std::pin::pin; #[cfg(target_os = "macos")] use system_configuration::{ core_foundation::{ @@ -29,6 +33,7 @@ use system_configuration::{ sys::schema_definitions::kSCPropNetProxiesHTTPSPort, sys::schema_definitions::kSCPropNetProxiesHTTPSProxy, }; +use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(target_os = "windows")] use winreg::enums::HKEY_CURRENT_USER; #[cfg(target_os = "windows")] @@ -96,6 +101,115 @@ pub struct NoProxy { domains: DomainMatcher, } +/// A trait for creating a trait object that implements AsyncRead and AsyncWrite. +pub trait AsyncStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} +impl AsyncStream for RW {} +pub struct AsyncStreamWrapper(Box); +impl From> for AsyncStreamWrapper { + fn from(value: Box) -> Self { + Self(value) + } +} +impl AsyncRead for AsyncStreamWrapper { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + pin!(&mut self.0).poll_read(cx, buf) + } +} +impl AsyncWrite for AsyncStreamWrapper { + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.0).poll_flush(cx) + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.0).poll_shutdown(cx) + } + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + pin!(&mut self.0).poll_write(cx, buf) + } + fn poll_write_vectored( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + pin!(&mut self.0).poll_write_vectored(cx, bufs) + } + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} +impl Debug for AsyncStreamWrapper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AsyncStreamWrapper") + } +} +impl Connection for AsyncStreamWrapper { + fn connected(&self) -> Connected { + Connected::new() + } +} + +/// A trait to define custom proxy protocol. +/// `Box` implements `IntoProxyScheme`. +/// # Example +/// ``` +/// use std::error::Error; +/// use tokio::net::TcpStream; +/// +/// use async_trait::async_trait; +/// use http::Uri; +/// use reqwest::{AsyncStream, CustomProxyProtocol}; +/// +/// #[derive(Clone)] +/// struct Example(); +/// #[async_trait] +/// impl CustomProxyProtocol for Example { +/// async fn connect( +/// &self, +/// dst: Uri, +/// ) -> Result, Box> { +/// let host = dst.host().ok_or("host is None")?; +/// let port = match (dst.scheme_str(), dst.port_u16()) { +/// (_, Some(p)) => p, +/// (Some("http"), None) => 80, +/// (Some("https"), None) => 443, +/// _ => return Err("scheme is unknown and port is None.".into()), +/// }; +/// eprintln!("Connecting to {}:{}", host, port); +/// Ok(Box::new( +/// TcpStream::connect(format!("{}:{}", host, port)).await?, +/// )) +/// } +/// } +/// ``` +#[async_trait] +pub trait CustomProxyProtocol: Sync + Send + DynClone { + /// Establish an OSI layer 4(ex. TCP) connection to the web server. + async fn connect( + &self, + dst: Uri, + ) -> Result, Box>; +} +dyn_clone::clone_trait_object!(CustomProxyProtocol); + +impl IntoProxyScheme for Box { + fn into_proxy_scheme(self) -> crate::Result { + Ok(ProxyScheme::Custom(self)) + } +} + /// A particular scheme used for proxying requests. /// /// For example, HTTP vs SOCKS5 @@ -109,6 +223,7 @@ pub enum ProxyScheme { auth: Option, host: http::uri::Authority, }, + Custom(Box), #[cfg(feature = "socks")] Socks5 { addr: SocketAddr, @@ -121,7 +236,6 @@ impl ProxyScheme { fn maybe_http_auth(&self) -> Option<&HeaderValue> { match self { ProxyScheme::Http { auth, .. } | ProxyScheme::Https { auth, .. } => auth.as_ref(), - #[cfg(feature = "socks")] _ => None, } } @@ -612,6 +726,7 @@ impl ProxyScheme { let header = encode_basic_auth(&username.into(), &password.into()); *auth = Some(header); } + ProxyScheme::Custom(_) => {} #[cfg(feature = "socks")] ProxyScheme::Socks5 { ref mut auth, .. } => { *auth = Some((username.into(), password.into())); @@ -631,6 +746,7 @@ impl ProxyScheme { *auth = update.clone(); } } + ProxyScheme::Custom(_) => {} #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => {} } @@ -684,6 +800,7 @@ impl ProxyScheme { match self { ProxyScheme::Http { .. } => "http", ProxyScheme::Https { .. } => "https", + ProxyScheme::Custom(_) => "custom", #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => "socks5", } @@ -694,6 +811,7 @@ impl ProxyScheme { match self { ProxyScheme::Http { host, .. } => host.as_str(), ProxyScheme::Https { host, .. } => host.as_str(), + ProxyScheme::Custom(_) => panic!("custom"), #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => panic!("socks5"), } @@ -705,6 +823,7 @@ impl fmt::Debug for ProxyScheme { match self { ProxyScheme::Http { auth: _auth, host } => write!(f, "http://{}", host), ProxyScheme::Https { auth: _auth, host } => write!(f, "https://{}", host), + ProxyScheme::Custom(_) => write!(f, "custom://"), #[cfg(feature = "socks")] ProxyScheme::Socks5 { addr, @@ -1075,8 +1194,7 @@ mod tests { let (scheme, host) = match p.intercept(&url(s)).unwrap() { ProxyScheme::Http { host, .. } => ("http", host), ProxyScheme::Https { host, .. } => ("https", host), - #[cfg(feature = "socks")] - _ => panic!("intercepted as socks"), + _ => panic!("intercepted as not http or https"), }; http::Uri::builder() .scheme(scheme) From 7494c75a20aefadfbf2031703d575379a12011ed Mon Sep 17 00:00:00 2001 From: WinLinux1028 Date: Sat, 18 Nov 2023 15:34:51 +0900 Subject: [PATCH 2/3] Fixed an issue where it was not possible to connect using HTTP/2 via Custom proxy. --- src/connect.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/connect.rs b/src/connect.rs index 526ee5b30..4a5c4ae81 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -338,17 +338,17 @@ impl Connector { } } #[cfg(feature = "__rustls")] - Inner::RustlsTls { tls_proxy, .. } => { + Inner::RustlsTls { tls, .. } => { if dst.scheme() == Some(&Scheme::HTTPS) { use std::convert::TryFrom; use tokio_rustls::TlsConnector as RustlsConnector; - let tls = tls_proxy.clone(); let host = dst.host().ok_or("no host in url")?.to_string(); let conn = p.connect(dst).await?; let conn = AsyncStreamWrapper::from(conn); let server_name = rustls::ServerName::try_from(host.as_str()) .map_err(|_| "Invalid Server Name")?; + let tls = tls.clone(); let io = RustlsConnector::from(tls) .connect(server_name, conn) .await?; From f46b0fe91df55af7e27198204a623e8bcd18f4a5 Mon Sep 17 00:00:00 2001 From: WinLinux1028 Date: Sat, 18 Nov 2023 17:06:58 +0900 Subject: [PATCH 3/3] Changed so that Custom proxy can specify is_proxy --- examples/custom_proxy_protocol.rs | 32 +++++++++++++-------- src/connect.rs | 6 ++-- src/lib.rs | 2 +- src/proxy.rs | 46 ++++++++++++++++++++----------- 4 files changed, 54 insertions(+), 32 deletions(-) diff --git a/examples/custom_proxy_protocol.rs b/examples/custom_proxy_protocol.rs index c0494d7e6..ce06c10e0 100644 --- a/examples/custom_proxy_protocol.rs +++ b/examples/custom_proxy_protocol.rs @@ -6,7 +6,7 @@ use tokio::{ use async_trait::async_trait; use http::Uri; -use reqwest::{AsyncStream, Client, CustomProxyProtocol, Proxy}; +use reqwest::{AsyncStreamWrapper, Client, CustomProxyProtocol, Proxy}; #[tokio::main] async fn main() { @@ -16,15 +16,16 @@ async fn main() { .http1_only() .build() .unwrap(); - let response = client + let mut response = client .get("http://www.hal.ipc.i.u-tokyo.ac.jp/~nakada/prog2015/alice.txt") .send() .await .unwrap(); - let body = response.bytes().await.unwrap(); let mut stdout = std::io::stdout(); - stdout.write_all(&body).unwrap(); + while let Some(chunk) = response.chunk().await.unwrap() { + stdout.write_all(&chunk).unwrap(); + } stdout.flush().unwrap(); } @@ -35,7 +36,7 @@ impl CustomProxyProtocol for Example { async fn connect( &self, dst: Uri, - ) -> Result, Box> { + ) -> Result> { let host = dst.host().ok_or("host is None")?; let port = match (dst.scheme_str(), dst.port_u16()) { (_, Some(p)) => p, @@ -44,14 +45,20 @@ impl CustomProxyProtocol for Example { _ => return Err("scheme is unknown and port is None.".into()), }; eprintln!("Connecting to {}:{}", host, port); - Ok(Box::new(WrapStream( - TcpStream::connect(format!("{}:{}", host, port)).await?, - ))) + Ok(AsyncStreamWrapper::new( + WrapStream(TcpStream::connect(format!("{}:{}", host, port)).await?), + false, + )) } } -struct WrapStream(RW); -impl AsyncRead for WrapStream { +struct WrapStream(RW) +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static; +impl AsyncRead for WrapStream +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -61,7 +68,10 @@ impl AsyncRead for WrapStream { pin!(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for WrapStream { +impl AsyncWrite for WrapStream +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, diff --git a/src/connect.rs b/src/connect.rs index 4a5c4ae81..5840e3581 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -327,7 +327,6 @@ impl Connector { if dst.scheme() == Some(&Scheme::HTTPS) { let host = dst.host().ok_or("no host in url")?.to_string(); let conn = p.connect(dst).await?; - let conn = AsyncStreamWrapper::from(conn); let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector.connect(&host, conn).await?; return Ok(Conn { @@ -345,7 +344,6 @@ impl Connector { let host = dst.host().ok_or("no host in url")?.to_string(); let conn = p.connect(dst).await?; - let conn = AsyncStreamWrapper::from(conn); let server_name = rustls::ServerName::try_from(host.as_str()) .map_err(|_| "Invalid Server Name")?; let tls = tls.clone(); @@ -364,8 +362,8 @@ impl Connector { } return p.connect(dst).await.map(|tcp| Conn { - inner: self.verbose.wrap(AsyncStreamWrapper::from(tcp)), - is_proxy: false, + is_proxy: tcp.is_http_proxy, + inner: self.verbose.wrap(tcp), tls_info: false, }); } diff --git a/src/lib.rs b/src/lib.rs index bded6f23f..eafc6d154 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -326,7 +326,7 @@ if_hyper! { pub use self::async_impl::{ Body, Client, ClientBuilder, Request, RequestBuilder, Response, Upgraded, }; - pub use self::proxy::{Proxy,NoProxy,CustomProxyProtocol,AsyncStream}; + pub use self::proxy::{Proxy,NoProxy,CustomProxyProtocol,AsyncStreamWrapper}; #[cfg(feature = "__tls")] // Re-exports, to be removed in a future release pub use tls::{Certificate, Identity}; diff --git a/src/proxy.rs b/src/proxy.rs index 9136ec9be..6980e41e5 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -101,13 +101,26 @@ pub struct NoProxy { domains: DomainMatcher, } -/// A trait for creating a trait object that implements AsyncRead and AsyncWrite. pub trait AsyncStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} impl AsyncStream for RW {} -pub struct AsyncStreamWrapper(Box); -impl From> for AsyncStreamWrapper { - fn from(value: Box) -> Self { - Self(value) +/// A wrapper for proxy connections and related information. +/// return type of [CustomProxyProtocol::connect]. +pub struct AsyncStreamWrapper { + pub(crate) inner: Box, + pub(crate) is_http_proxy: bool, +} +impl AsyncStreamWrapper { + /// Make a new instance of [AsyncStreamWrapper]. + /// If is_http_proxy is set to true, the connection will be treated as a connection to an http proxy. + /// This does not affect https. + pub fn new(stream: RW, is_http_proxy: bool) -> Self + where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + Self { + inner: Box::new(stream), + is_http_proxy, + } } } impl AsyncRead for AsyncStreamWrapper { @@ -116,7 +129,7 @@ impl AsyncRead for AsyncStreamWrapper { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - pin!(&mut self.0).poll_read(cx, buf) + pin!(&mut self.inner).poll_read(cx, buf) } } impl AsyncWrite for AsyncStreamWrapper { @@ -124,30 +137,30 @@ impl AsyncWrite for AsyncStreamWrapper { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - pin!(&mut self.0).poll_flush(cx) + pin!(&mut self.inner).poll_flush(cx) } fn poll_shutdown( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - pin!(&mut self.0).poll_shutdown(cx) + pin!(&mut self.inner).poll_shutdown(cx) } fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - pin!(&mut self.0).poll_write(cx, buf) + pin!(&mut self.inner).poll_write(cx, buf) } fn poll_write_vectored( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> std::task::Poll> { - pin!(&mut self.0).poll_write_vectored(cx, bufs) + pin!(&mut self.inner).poll_write_vectored(cx, bufs) } fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() + self.inner.is_write_vectored() } } impl Debug for AsyncStreamWrapper { @@ -170,7 +183,7 @@ impl Connection for AsyncStreamWrapper { /// /// use async_trait::async_trait; /// use http::Uri; -/// use reqwest::{AsyncStream, CustomProxyProtocol}; +/// use reqwest::{AsyncStreamWrapper, CustomProxyProtocol}; /// /// #[derive(Clone)] /// struct Example(); @@ -179,7 +192,7 @@ impl Connection for AsyncStreamWrapper { /// async fn connect( /// &self, /// dst: Uri, -/// ) -> Result, Box> { +/// ) -> Result> { /// let host = dst.host().ok_or("host is None")?; /// let port = match (dst.scheme_str(), dst.port_u16()) { /// (_, Some(p)) => p, @@ -188,19 +201,20 @@ impl Connection for AsyncStreamWrapper { /// _ => return Err("scheme is unknown and port is None.".into()), /// }; /// eprintln!("Connecting to {}:{}", host, port); -/// Ok(Box::new( +/// Ok(AsyncStreamWrapper::new( /// TcpStream::connect(format!("{}:{}", host, port)).await?, +/// false, /// )) /// } /// } /// ``` #[async_trait] pub trait CustomProxyProtocol: Sync + Send + DynClone { - /// Establish an OSI layer 4(ex. TCP) connection to the web server. + /// Establish an TCP connection to the web server. async fn connect( &self, dst: Uri, - ) -> Result, Box>; + ) -> Result>; } dyn_clone::clone_trait_object!(CustomProxyProtocol);