Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting custom proxy protocol #2042

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ bytes = "1.0"
serde = "1.0"
serde_urlencoded = "0.7.1"
tower-service = "0.3"
async-trait = "0.1"
dyn-clone = "1"
futures-core = { version = "0.3.0", default-features = false }
futures-util = { version = "0.3.0", default-features = false }

Expand Down
98 changes: 98 additions & 0 deletions examples/custom_proxy_protocol.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use std::{error::Error, io::Write, pin::pin};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};

use async_trait::async_trait;
use http::Uri;
use reqwest::{AsyncStreamWrapper, Client, CustomProxyProtocol, Proxy};

#[tokio::main]
async fn main() {
let proxy: Box<dyn CustomProxyProtocol> = Box::new(Example());
let client = Client::builder()
.proxy(Proxy::all(proxy).unwrap())
.http1_only()
.build()
.unwrap();
let mut response = client
.get("http://www.hal.ipc.i.u-tokyo.ac.jp/~nakada/prog2015/alice.txt")
.send()
.await
.unwrap();

let mut stdout = std::io::stdout();
while let Some(chunk) = response.chunk().await.unwrap() {
stdout.write_all(&chunk).unwrap();
}
stdout.flush().unwrap();
}

#[derive(Clone)]
struct Example();
#[async_trait]
impl CustomProxyProtocol for Example {
async fn connect(
&self,
dst: Uri,
) -> Result<AsyncStreamWrapper, Box<dyn Error + Send + Sync + 'static>> {
let host = dst.host().ok_or("host is None")?;
let port = match (dst.scheme_str(), dst.port_u16()) {
(_, Some(p)) => p,
(Some("http"), None) => 80,
(Some("https"), None) => 443,
_ => return Err("scheme is unknown and port is None.".into()),
};
eprintln!("Connecting to {}:{}", host, port);
Ok(AsyncStreamWrapper::new(
WrapStream(TcpStream::connect(format!("{}:{}", host, port)).await?),
false,
))
}
}

struct WrapStream<RW>(RW)
where
RW: AsyncRead + AsyncWrite + Send + Unpin + 'static;
impl<RW> AsyncRead for WrapStream<RW>
where
RW: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
eprintln!("read");
pin!(&mut self.0).poll_read(cx, buf)
}
}
impl<RW> AsyncWrite for WrapStream<RW>
where
RW: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
eprintln!("write");
std::io::stderr().write_all(buf).unwrap();
pin!(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
eprintln!("flush");
pin!(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
eprintln!("shutdown");
pin!(&mut self.0).poll_shutdown(cx)
}
}
127 changes: 75 additions & 52 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use self::native_tls_conn::NativeTlsConn;
use self::rustls_tls_conn::RustlsTlsConn;
use crate::dns::DynResolver;
use crate::error::BoxError;
use crate::proxy::{Proxy, ProxyScheme};
use crate::proxy::{AsyncStreamWrapper, Proxy, ProxyScheme};

pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;

Expand Down Expand Up @@ -179,7 +179,7 @@ impl Connector {
ProxyScheme::Socks5 {
remote_dns: true, ..
} => socks::DnsResolve::Proxy,
ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
_ => {
unreachable!("connect_socks is only called for socks proxies");
}
};
Expand Down Expand Up @@ -319,6 +319,54 @@ impl Connector {
let (proxy_dst, _auth) = match proxy_scheme {
ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
ProxyScheme::Custom(ref p) => {
let p = p.clone();
match &self.inner {
#[cfg(feature = "default-tls")]
Inner::DefaultTls(_http, tls) => {
if dst.scheme() == Some(&Scheme::HTTPS) {
let host = dst.host().ok_or("no host in url")?.to_string();
let conn = p.connect(dst).await?;
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let io = tls_connector.connect(&host, conn).await?;
return Ok(Conn {
inner: self.verbose.wrap(NativeTlsConn { inner: io }),
is_proxy: false,
tls_info: self.tls_info,
});
}
}
#[cfg(feature = "__rustls")]
Inner::RustlsTls { tls, .. } => {
if dst.scheme() == Some(&Scheme::HTTPS) {
use std::convert::TryFrom;
use tokio_rustls::TlsConnector as RustlsConnector;

let host = dst.host().ok_or("no host in url")?.to_string();
let conn = p.connect(dst).await?;
let server_name = rustls::ServerName::try_from(host.as_str())
.map_err(|_| "Invalid Server Name")?;
let tls = tls.clone();
let io = RustlsConnector::from(tls)
.connect(server_name, conn)
.await?;
return Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
is_proxy: false,
tls_info: false,
});
}
}
#[cfg(not(feature = "__tls"))]
Inner::Http(_) => (),
}

return p.connect(dst).await.map(|tcp| Conn {
is_proxy: tcp.is_http_proxy,
inner: self.verbose.wrap(tcp),
tls_info: false,
});
}
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
};
Expand Down Expand Up @@ -466,6 +514,13 @@ trait TlsInfoFactory {
fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

#[cfg(feature = "__tls")]
impl TlsInfoFactory for AsyncStreamWrapper {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
None
}
}

#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
Expand All @@ -474,7 +529,10 @@ impl TlsInfoFactory for tokio::net::TcpStream {
}

#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<tokio::net::TcpStream> {
impl<RW> TlsInfoFactory for hyper_tls::MaybeHttpsStream<RW>
where
RW: AsyncRead + AsyncWrite + Unpin,
{
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
match self {
hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
Expand All @@ -484,20 +542,10 @@ impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<tokio::net::TcpStream> {
}

#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::TlsStream<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
.peer_certificate()
.ok()
.flatten()
.and_then(|c| c.to_der().ok());
Some(crate::tls::TlsInfo { peer_certificate })
}
}

#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<tokio::net::TcpStream> {
impl<RW> TlsInfoFactory for tokio_native_tls::TlsStream<RW>
where
RW: AsyncRead + AsyncWrite + Unpin,
{
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
Expand All @@ -510,7 +558,7 @@ impl TlsInfoFactory for tokio_native_tls::TlsStream<tokio::net::TcpStream> {
}

#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream> {
impl<RW> TlsInfoFactory for hyper_rustls::MaybeHttpsStream<RW> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
match self {
hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
Expand All @@ -520,22 +568,7 @@ impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream> {
}

#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::TlsStream<tokio::net::TcpStream> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
.1
.peer_certificates()
.and_then(|certs| certs.first())
.map(|c| c.0.clone());
Some(crate::tls::TlsInfo { peer_certificate })
}
}

#[cfg(feature = "__rustls")]
impl TlsInfoFactory
for tokio_rustls::client::TlsStream<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>>
{
impl<RW> TlsInfoFactory for tokio_rustls::TlsStream<RW> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
Expand All @@ -548,7 +581,7 @@ impl TlsInfoFactory
}

#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
impl<RW> TlsInfoFactory for tokio_rustls::client::TlsStream<RW> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
Expand All @@ -561,11 +594,10 @@ impl TlsInfoFactory for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
}

pub(crate) trait AsyncConn:
AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
AsyncRead + AsyncWrite + Connection + Send + Unpin + 'static
{
}

impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
impl<T: AsyncRead + AsyncWrite + Connection + Send + Unpin + 'static> AsyncConn for T {}

#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
Expand Down Expand Up @@ -824,13 +856,10 @@ mod native_tls_conn {
}
}

impl TlsInfoFactory for NativeTlsConn<tokio::net::TcpStream> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
}
}

impl TlsInfoFactory for NativeTlsConn<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
impl<RW> TlsInfoFactory for NativeTlsConn<RW>
where
RW: AsyncRead + AsyncWrite + Unpin,
{
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
}
Expand Down Expand Up @@ -917,13 +946,7 @@ mod rustls_tls_conn {
}
}

impl TlsInfoFactory for RustlsTlsConn<tokio::net::TcpStream> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
}
}

impl TlsInfoFactory for RustlsTlsConn<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>> {
impl<RW> TlsInfoFactory for RustlsTlsConn<RW> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ if_hyper! {
pub use self::async_impl::{
Body, Client, ClientBuilder, Request, RequestBuilder, Response, Upgraded,
};
pub use self::proxy::{Proxy,NoProxy};
pub use self::proxy::{Proxy,NoProxy,CustomProxyProtocol,AsyncStreamWrapper};
#[cfg(feature = "__tls")]
// Re-exports, to be removed in a future release
pub use tls::{Certificate, Identity};
Expand Down
Loading